// Program listing for file any_module_holder.h
// (torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h)
#pragma once
#include <torch/csrc/utils/variadic.h>
#include <torch/nn/modules/container/any_value.h>
namespace torch::nn {
class Module;
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~
/// Abstract, type-erased interface for a stored module. It extends
/// `AnyValue::Placeholder` with the module-specific operations below; every
/// operation is pure virtual and must be supplied by a concrete holder.
struct AnyModulePlaceholder : public AnyValue::Placeholder {
  /// Inherit the constructors of `AnyValue::Placeholder` unchanged.
  using AnyValue::Placeholder::Placeholder;

  /// Invokes the stored module's `forward()` with the given type-erased
  /// `arguments` and returns the result as an `AnyValue`.
  virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0;

  /// Returns the stored module as a `std::shared_ptr<Module>`.
  virtual std::shared_ptr<Module> ptr() = 0;

  /// Returns a newly allocated copy of this placeholder.
  virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0;

  /// Returns a newly allocated placeholder for a clone of the stored module.
  /// `device` is an optional device that is passed along to the clone
  /// operation.
  virtual std::unique_ptr<AnyModulePlaceholder> clone_module(
      std::optional<Device> device) const = 0;
};
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename ModuleType, typename... ArgumentTypes>
struct AnyModuleHolder : public AnyModulePlaceholder {
struct CheckedGetter {
template <typename T>
std::decay_t<T>&& operator()(size_t index) {
AT_ASSERT(index < arguments_.size());
auto& value = arguments_[index];
if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) {
return std::move(*maybe_value);
}
TORCH_CHECK(
false,
"Expected argument #",
index,
" to be of type ",
c10::demangle(typeid(T).name()),
", but received value of type ",
c10::demangle(value.type_info().name()));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
std::vector<AnyValue>& arguments_;
};
struct InvokeForward {
template <typename... Ts>
AnyValue operator()(Ts&&... ts) {
return AnyValue(module_->forward(std::forward<Ts>(ts)...));
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
std::shared_ptr<ModuleType>& module_;
};
explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_)
: AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {}
AnyValue forward(std::vector<AnyValue>&& arguments) override {
if (module->_forward_has_default_args()) {
TORCH_CHECK(
arguments.size() >= module->_forward_num_required_args() &&
arguments.size() <= sizeof...(ArgumentTypes),
c10::demangle(type_info.name()),
"'s forward() method expects at least ",
module->_forward_num_required_args(),
" argument(s) and at most ",
sizeof...(ArgumentTypes),
" argument(s), but received ",
arguments.size(),
".");
arguments = std::move(
module->_forward_populate_default_args(std::move(arguments)));
} else {
std::string use_default_args_macro_prompt = " If " +
c10::demangle(type_info.name()) +
"'s forward() method has default arguments, " +
"please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
TORCH_CHECK(
arguments.size() == sizeof...(ArgumentTypes),
c10::demangle(type_info.name()),
"'s forward() method expects ",
sizeof...(ArgumentTypes),
" argument(s), but received ",
arguments.size(),
".",
(arguments.size() < sizeof...(ArgumentTypes))
? use_default_args_macro_prompt
: "");
}
// FYI: During invocation of a module's `forward()` method, the values live
// in the `arguments` vector inside this function.
return torch::unpack<AnyValue, ArgumentTypes...>(
InvokeForward{module}, CheckedGetter{arguments});
}
std::shared_ptr<Module> ptr() override {
return module;
}
std::unique_ptr<AnyModulePlaceholder> copy() const override {
return std::make_unique<AnyModuleHolder>(*this);
}
std::unique_ptr<AnyModulePlaceholder> clone_module(
std::optional<Device> device) const override {
return std::make_unique<AnyModuleHolder>(
std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
}
std::shared_ptr<ModuleType> module;
};
} // namespace torch::nn