Rate this Page

Program Listing for File expanding_array.h

Return to documentation for file (torch/csrc/api/include/torch/expanding_array.h)

#pragma once

#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <optional>

#include <algorithm>
#include <array>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <vector>

namespace torch {

/// Fixed-size container holding exactly `D` values of type `T`.
///
/// It can be built in two ways:
/// - from exactly `D` values (a braced list, a `std::vector`, a
///   `c10::ArrayRef` or a `std::array`);
/// - from a single value, which is copied into all `D` slots.
///
/// The converting constructors are intentionally left implicit. A parameter of
/// this type therefore accepts any of those forms directly at the call site.
template <size_t D, typename T = int64_t>
class ExpandingArray {
 public:
  /// Builds from a braced list. Forwards to the `ArrayRef` overload, which
  /// verifies that the list holds exactly `D` elements.
  /*implicit*/ ExpandingArray(std::initializer_list<T> list)
      : ExpandingArray(c10::ArrayRef<T>(list)) {}

  /// Builds from a vector. Forwards to the `ArrayRef` overload for the size
  /// check.
  /*implicit*/ ExpandingArray(std::vector<T> vec)
      : ExpandingArray(c10::ArrayRef<T>(vec)) {}

  /// Copies the elements of `values` into this array.
  /// Fails through `TORCH_CHECK` unless `values` holds exactly `D` elements.
  /*implicit*/ ExpandingArray(c10::ArrayRef<T> values) {
    const auto received = values.size();
    // clang-format off
    TORCH_CHECK(
        received == D,
        "Expected ", D, " values, but instead got ", received);
    // clang-format on
    auto slot = values_.begin();
    for (const auto& value : values) {
      *slot = value;
      ++slot;
    }
  }

  /// Broadcasts `single_size` into every one of the `D` slots.
  /*implicit*/ ExpandingArray(T single_size) {
    std::fill(values_.begin(), values_.end(), single_size);
  }

  /// Copies an already correctly sized `std::array`, so no runtime size check
  /// is needed.
  /*implicit*/ ExpandingArray(const std::array<T, D>& values)
      : values_(values) {}

  /// Mutable access to the underlying `std::array`.
  std::array<T, D>& operator*() { return this->values_; }

  /// Read-only access to the underlying `std::array`.
  const std::array<T, D>& operator*() const { return this->values_; }

  /// Member access on the underlying `std::array` (e.g. `arr->at(0)`).
  std::array<T, D>* operator->() { return &this->values_; }

  /// Read-only member access on the underlying `std::array`.
  const std::array<T, D>* operator->() const { return &this->values_; }

  /// Exposes the stored values as a `c10::ArrayRef<T>`.
  /// NOTE: `c10::ArrayRef` is a non-owning view, so the result must not
  /// outlive this object.
  operator c10::ArrayRef<T>() const { return this->values_; }

  /// Number of stored values. This is always the compile-time constant `D`.
  size_t size() const noexcept { return D; }

 protected:
  // Storage for the `D` values. It is protected so that derived wrappers
  // (e.g. an optional-element variant) can fill it directly.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::array<T, D> values_;
};

/// Streams an `ExpandingArray`.
///
/// A one-element array is printed as that bare element. Any other size is
/// printed through its `c10::ArrayRef<T>` conversion.
/// Returns `stream` so calls can be chained.
template <size_t D, typename T>
std::ostream& operator<<(
    std::ostream& stream,
    const ExpandingArray<D, T>& expanding_array) {
  const bool single_element = expanding_array.size() == 1;
  if (!single_element) {
    const auto as_ref = static_cast<c10::ArrayRef<T>>(expanding_array);
    return stream << as_ref;
  }
  // Exactly one element exists here, so indexing slot 0 is always in range.
  return stream << expanding_array->at(0);
}

/// Variant of `ExpandingArray` whose elements are `std::optional<T>`, so
/// individual entries may be empty (`std::nullopt`).
///
/// It inherits every `ExpandingArray<D, std::optional<T>>` constructor. It also
/// accepts plain `T` values (braced list, vector, `ArrayRef`, `std::array` or a
/// single value), which are stored as engaged optionals.
template <size_t D, typename T = int64_t>
class ExpandingArrayWithOptionalElem
    : public ExpandingArray<D, std::optional<T>> {
  // Shorthand for the base class, used in the mem-initializers below.
  using Base = ExpandingArray<D, std::optional<T>>;

 public:
  using ExpandingArray<D, std::optional<T>>::ExpandingArray;

  /// Builds from a braced list of plain values. The list must hold exactly
  /// `D` elements (checked by the `ArrayRef` overload).
  /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list<T> list)
      : ExpandingArrayWithOptionalElem(c10::ArrayRef<T>(list)) {}

  /// Builds from a vector of plain values. Forwards to the `ArrayRef`
  /// overload for the size check.
  /*implicit*/ ExpandingArrayWithOptionalElem(std::vector<T> vec)
      : ExpandingArrayWithOptionalElem(c10::ArrayRef<T>(vec)) {}

  /// Stores each element of `values` as an engaged optional.
  /// Fails through `TORCH_CHECK` unless `values` holds exactly `D` elements.
  /*implicit*/ ExpandingArrayWithOptionalElem(c10::ArrayRef<T> values)
      // The base has no default constructor, so it is first filled with an
      // empty placeholder that is overwritten below. `std::nullopt` works for
      // any `T`. The previous `0` placeholder required `std::optional<T>` to
      // be constructible from an int.
      : Base(std::optional<T>(std::nullopt)) {
    // clang-format off
    TORCH_CHECK(
        values.size() == D,
        "Expected ", D, " values, but instead got ", values.size());
    // clang-format on
    std::copy(values.begin(), values.end(), this->values_.begin());
  }

  /// Stores `single_size` (as an engaged optional) in every one of the `D`
  /// slots. The base constructor performs the broadcast directly.
  /*implicit*/ ExpandingArrayWithOptionalElem(T single_size)
      : Base(std::optional<T>(single_size)) {}

  /// Stores each element of an already correctly sized `std::array` as an
  /// engaged optional.
  /*implicit*/ ExpandingArrayWithOptionalElem(const std::array<T, D>& values)
      // Empty placeholder; every slot is overwritten below.
      : Base(std::optional<T>(std::nullopt)) {
    std::copy(values.begin(), values.end(), this->values_.begin());
  }
};

/// Streams an `ExpandingArrayWithOptionalElem`.
///
/// Engaged elements are formatted with `c10::str`; empty elements are printed
/// as `None`. A one-element array prints just that formatted element.
/// Otherwise the formatted elements are printed as a
/// `c10::ArrayRef<std::string>`. Returns `stream` so calls can be chained.
template <size_t D, typename T>
std::ostream& operator<<(
    std::ostream& stream,
    const ExpandingArrayWithOptionalElem<D, T>& expanding_array_with_opt_elem) {
  // Single formatting rule shared by both branches, so they cannot drift
  // apart.
  const auto format_elem = [](const std::optional<T>& elem) -> std::string {
    return elem.has_value() ? c10::str(elem.value()) : std::string("None");
  };

  if (expanding_array_with_opt_elem.size() == 1) {
    stream << format_elem(expanding_array_with_opt_elem->at(0));
    return stream;
  }

  std::vector<std::string> str_array;
  // The final size is known up front, so allocate once instead of regrowing.
  str_array.reserve(expanding_array_with_opt_elem.size());
  for (const auto& elem : *expanding_array_with_opt_elem) {
    str_array.emplace_back(format_elem(elem));
  }
  stream << c10::ArrayRef<std::string>(str_array);
  return stream;
}

} // namespace torch