Define FORWARD_HAS_DEFAULT_ARGS#
Defined in File common.h
Define Documentation#
-
FORWARD_HAS_DEFAULT_ARGS(...)
template <typename ModuleType, typename... ArgumentTypes> \
friend struct
torch::nn::AnyModuleHolder; \
bool _forward_has_default_args() override { \
return true; \
} \
unsigned int _forward_num_required_args() override { \
std::vector<std::pair<unsigned int,
torch::nn::AnyValue>> args_info{ \
__VA_ARGS__}; \
return std::begin(args_info)->first; \
} \
std::vector<
torch::nn::AnyValue> _forward_populate_default_args( \
std::vector<
torch::nn::AnyValue>&& arguments) override { \
std::vector<std::pair<unsigned int,
torch::nn::AnyValue>> args_info{ \
__VA_ARGS__}; \
unsigned int num_all_args = std::rbegin(args_info)->first + 1; \
TORCH_INTERNAL_ASSERT( \
arguments.size() >= _forward_num_required_args() && \
arguments.size() <= num_all_args); \
std::vector<
torch::nn::AnyValue> ret = std::move(arguments); \
ret.reserve(num_all_args); \
for (auto& arg_info : args_info) { \
if (arg_info.first > ret.size() - 1) \
ret.emplace_back(std::move(arg_info.second)); \
} \
return ret; \
}
# This macro enables a module with default arguments in its forward method to be used in a Sequential module.
Example usage:
Let’s say we have a module declared like this:
struct MImpl : torch::nn::Module { public: explicit MImpl(int value_) : value(value_) {} torch::Tensor forward(int a, int b = 2, double c = 3.0) { return torch::tensor(a + b + c); } private: int value; }; TORCH_MODULE(M);
If we try to use it in a Sequential module and run forward:
torch::nn::Sequential seq(M(1)); seq->forward(1);
We will receive the following error message:
MImpl's forward() method expects 3 argument(s), but received 1. If MImpl's forward() method has default arguments, please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.
The right way to fix this error is to use the
FORWARD_HAS_DEFAULT_ARGS
macro when declaring the module:Now, running the following would work:struct MImpl : torch::nn::Module { public: explicit MImpl(int value_) : value(value_) {} torch::Tensor forward(int a, int b = 2, double c = 3.0) { return torch::tensor(a + b + c); } protected: /* NOTE: looking at the argument list of `forward`: `forward(int a, int b = 2, double c = 3.0)` we saw the following default arguments: ---------------------------------------------------------------- 0-based index of default | Default value of arg arg in forward arg list | (wrapped by `torch::nn::AnyValue()`) ---------------------------------------------------------------- 1 | torch::nn::AnyValue(2) 2 | torch::nn::AnyValue(3.0) ---------------------------------------------------------------- Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS` macro: *‍/ FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2, torch::nn::AnyValue(3.0)}) private: int value; }; TORCH_MODULE(M);
torch::nn::Sequential seq(M(1)); seq->forward(1); // This correctly populates the default arguments for `MImpl::forward`