Program Listing for File utils.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/utils.h)
#pragma once
#include <c10/util/ArrayRef.h>
#include <c10/util/irange.h>
#include <optional>
#include <vector>
namespace torch::nn::modules::utils {
// Builds a vector by walking `t` from its last element to its first and
// writing each element `n` times in a row; e.g. t = {1, 2}, n = 2 yields
// {2, 2, 1, 1}.
// Handy for converting the padding argument of Conv and Pooling modules into
// the form taken by `F::pad`.
//
// Counterpart of `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`.
inline std::vector<int64_t> _reverse_repeat_vector(
    c10::ArrayRef<int64_t> t,
    int64_t n) {
  // A negative repeat count is treated as an internal invariant violation.
  TORCH_INTERNAL_ASSERT(n >= 0);
  std::vector<int64_t> repeated;
  repeated.reserve(t.size() * n);
  // Count down from the back of `t`; `pos` is one past the element being read.
  for (size_t pos = t.size(); pos > 0; --pos) {
    const int64_t value = t[pos - 1];
    for (int64_t copy = 0; copy < n; ++copy) {
      repeated.emplace_back(value);
    }
  }
  return repeated;
}
// Fills in the unspecified entries of `out_size` from the trailing entries of
// `defaults`.
//
// `out_size[i]` is paired with the i-th element of the last `out_size.size()`
// elements of `defaults`. For every position the result holds the value stored
// in `out_size[i]` when it has one, and the paired default otherwise.
//
// `defaults` must be strictly longer than `out_size`; otherwise the
// TORCH_CHECK below fails with "Input dimension should be at least <k + 1>",
// where k is `out_size.size()`.
//
// Returns a vector with exactly `out_size.size()` elements.
inline std::vector<int64_t> _list_with_default(
    c10::ArrayRef<std::optional<int64_t>> out_size,
    c10::IntArrayRef defaults) {
  TORCH_CHECK(
      defaults.size() > out_size.size(),
      "Input dimension should be at least ",
      out_size.size() + 1);
  // Only the tail of `defaults` lines up with `out_size`. The check above
  // guarantees the subtraction below cannot wrap around.
  const c10::IntArrayRef trailing_defaults =
      defaults.slice(defaults.size() - out_size.size(), out_size.size());
  std::vector<int64_t> ret;
  // The final size is known up front, so allocate once instead of growing.
  ret.reserve(out_size.size());
  for (const auto i : c10::irange(out_size.size())) {
    ret.emplace_back(out_size[i].value_or(trailing_defaults[i]));
  }
  return ret;
}
} // namespace torch::nn::modules::utils