Program Listing for File expanding_array.h
↰ Return to documentation for file (torch/csrc/api/include/torch/expanding_array.h)
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <optional>
#include <algorithm>
#include <array>
#include <cstdint>
#include <initializer_list>
#include <string>
#include <vector>
namespace torch {
/// Fixed-size holder of exactly `D` values of type `T`, backed by a
/// `std::array<T, D>`.
///
/// An instance can be built from a container of `D` values (initializer list,
/// `std::vector`, `c10::ArrayRef`, `std::array`) or from one single value,
/// which is then replicated into all `D` slots. All converting constructors
/// are intentionally implicit.
template <size_t D, typename T = int64_t>
class ExpandingArray {
 public:
  /// Builds the array from a braced list by forwarding to the `ArrayRef`
  /// constructor, which enforces that exactly `D` values were given.
  /*implicit*/ ExpandingArray(std::initializer_list<T> list)
      : ExpandingArray(c10::ArrayRef<T>(list)) {}

  /// Builds the array from a vector by forwarding to the `ArrayRef`
  /// constructor, which enforces that exactly `D` values were given.
  /*implicit*/ ExpandingArray(std::vector<T> vec)
      : ExpandingArray(c10::ArrayRef<T>(vec)) {}

  /// Builds the array from a view over `D` values. The size requirement is
  /// enforced with `TORCH_CHECK` before anything is copied.
  /*implicit*/ ExpandingArray(c10::ArrayRef<T> values) {
    // clang-format off
    TORCH_CHECK(
        values.size() == D,
        "Expected ", D, " values, but instead got ", values.size());
    // clang-format on
    // The check above guarantees the view holds exactly D elements.
    std::copy_n(values.begin(), D, values_.begin());
  }

  /// Builds the array by repeating `single_size` in every one of the `D`
  /// slots.
  /*implicit*/ ExpandingArray(T single_size) {
    std::fill(values_.begin(), values_.end(), single_size);
  }

  /// Builds the array as a copy of an already correctly sized `std::array`.
  /*implicit*/ ExpandingArray(const std::array<T, D>& values)
      : values_(values) {}

  /// Number of stored values; always the template parameter `D`.
  size_t size() const noexcept {
    return D;
  }

  /// Mutable access to the underlying `std::array`.
  std::array<T, D>& operator*() {
    return values_;
  }

  /// Read-only access to the underlying `std::array`.
  const std::array<T, D>& operator*() const {
    return values_;
  }

  /// Pointer-style mutable access to the underlying `std::array`.
  std::array<T, D>* operator->() {
    return &values_;
  }

  /// Pointer-style read-only access to the underlying `std::array`.
  const std::array<T, D>* operator->() const {
    return &values_;
  }

  /// Conversion to a `c10::ArrayRef<T>` built from the stored array.
  operator c10::ArrayRef<T>() const {
    return values_;
  }

 protected:
  /// The backing storage. Protected so that subclasses can fill it directly.
  // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
  std::array<T, D> values_;
};
/// Streams an `ExpandingArray`.
///
/// When the array holds exactly one value, that value is written on its own;
/// otherwise the array is converted to a `c10::ArrayRef<T>` and streamed
/// through that type's `operator<<`.
template <size_t D, typename T>
std::ostream& operator<<(
    std::ostream& stream,
    const ExpandingArray<D, T>& expanding_array) {
  const bool holds_single_value = expanding_array.size() == 1;
  if (!holds_single_value) {
    return stream << static_cast<c10::ArrayRef<T>>(expanding_array);
  }
  // Single-value case: print the bare element.
  return stream << expanding_array->at(0);
}
/// An `ExpandingArray` whose `D` slots are `std::optional<T>`.
///
/// All base-class constructors are inherited, so the array can be built from
/// containers of `std::optional<T>` (or a single `std::optional<T>`). The
/// constructors declared here additionally accept plain `T` values, each of
/// which is stored as an engaged optional.
template <size_t D, typename T = int64_t>
class ExpandingArrayWithOptionalElem
    : public ExpandingArray<D, std::optional<T>> {
 public:
  using ExpandingArray<D, std::optional<T>>::ExpandingArray;

  /// Builds the array from a braced list of plain values by forwarding to the
  /// `ArrayRef` constructor, which enforces that exactly `D` values were
  /// given.
  /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list<T> list)
      : ExpandingArrayWithOptionalElem(c10::ArrayRef<T>(list)) {}

  /// Builds the array from a vector of plain values by forwarding to the
  /// `ArrayRef` constructor, which enforces that exactly `D` values were
  /// given.
  /*implicit*/ ExpandingArrayWithOptionalElem(std::vector<T> vec)
      : ExpandingArrayWithOptionalElem(c10::ArrayRef<T>(vec)) {}

  /// Builds the array from a view over `D` plain values. The size requirement
  /// is enforced with `TORCH_CHECK` before anything is copied.
  /*implicit*/ ExpandingArrayWithOptionalElem(c10::ArrayRef<T> values)
      // The base class has no default constructor, so seed every slot with an
      // empty optional. Using `std::nullopt` (rather than a literal `0`) does
      // not require `T` to be constructible from `int`.
      : ExpandingArray<D, std::optional<T>>(std::nullopt) {
    // clang-format off
    TORCH_CHECK(
        values.size() == D,
        "Expected ", D, " values, but instead got ", values.size());
    // clang-format on
    // Each assignment of a `T` engages the corresponding optional slot.
    std::copy(values.begin(), values.end(), this->values_.begin());
  }

  /// Builds the array by repeating `single_size` (as an engaged optional) in
  /// every one of the `D` slots.
  /*implicit*/ ExpandingArrayWithOptionalElem(T single_size)
      // Hand the wrapped value straight to the base "repeat one value"
      // constructor; no placeholder fill is needed.
      : ExpandingArray<D, std::optional<T>>(std::optional<T>(single_size)) {}

  /// Builds the array from a `std::array` of plain values, storing each one
  /// as an engaged optional.
  /*implicit*/ ExpandingArrayWithOptionalElem(const std::array<T, D>& values)
      // Seed with empty optionals (see the `ArrayRef` constructor), then
      // overwrite every slot below.
      : ExpandingArray<D, std::optional<T>>(std::nullopt) {
    std::copy(values.begin(), values.end(), this->values_.begin());
  }
};
/// Streams an `ExpandingArrayWithOptionalElem`.
///
/// Each slot is rendered with `c10::str` when it holds a value and as the
/// text "None" when it is empty. A one-slot array is written as that single
/// rendering; otherwise the renderings are collected and streamed as a
/// `c10::ArrayRef<std::string>`.
template <size_t D, typename T>
std::ostream& operator<<(
    std::ostream& stream,
    const ExpandingArrayWithOptionalElem<D, T>& expanding_array_with_opt_elem) {
  // Text form of one slot: the formatted value, or "None" for an empty slot.
  const auto render_slot = [](const std::optional<T>& slot) -> std::string {
    return slot.has_value() ? c10::str(slot.value()) : "None";
  };

  if (expanding_array_with_opt_elem.size() != 1) {
    std::vector<std::string> rendered_slots;
    for (const auto& slot : *expanding_array_with_opt_elem) {
      rendered_slots.push_back(render_slot(slot));
    }
    return stream << c10::ArrayRef<std::string>(rendered_slots);
  }

  // Single-slot case: print the bare rendering.
  return stream << render_slot(expanding_array_with_opt_elem->at(0));
}
} // namespace torch