Rate this Page

Program Listing for File common.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/common.h)

struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}
  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }
 private:
  int value;
};
TORCH_MODULE(M);
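If we try to use it in a Sequential module and call forward() with only the
required argument:
torch::nn::Sequential seq(M(1));
seq->forward(1);
the call fails with an error message that ends with:
If MImpl's forward() method has default arguments, please make sure
the forward() method is declared with a corresponding
`FORWARD_HAS_DEFAULT_ARGS` macro.
The fix is to declare the module with the `FORWARD_HAS_DEFAULT_ARGS` macro: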
struct MImpl : torch::nn::Module {
 public:
  explicit MImpl(int value_) : value(value_) {}
  torch::Tensor forward(int a, int b = 2, double c = 3.0) {
    return torch::tensor(a + b + c);
  }
 protected:
  /*
  NOTE: looking at the argument list of `forward`:
  `forward(int a, int b = 2, double c = 3.0)`
  we saw the following default arguments:
  ----------------------------------------------------------------
  0-based index of default |         Default value of arg
  arg in forward arg list  |  (wrapped by `torch::nn::AnyValue()`)
  ----------------------------------------------------------------
              1            |       torch::nn::AnyValue(2)
              2            |       torch::nn::AnyValue(3.0)
  ----------------------------------------------------------------
  Thus we pass the following arguments to the `FORWARD_HAS_DEFAULT_ARGS`
  macro:
  */
  FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)}, {2,
  torch::nn::AnyValue(3.0)})
 private:
  int value;
};
TORCH_MODULE(M);
#pragma once

/// Declares, inside a module class, that its `forward()` method has default
/// arguments, and generates the overrides that report and fill them in.
///
/// Each macro argument is a `{index, torch::nn::AnyValue(default)}` pair,
/// where `index` is the 0-based position of a defaulted argument in the
/// `forward()` parameter list. The generated code treats the FIRST pair's
/// index as the number of required arguments and the LAST pair's index + 1 as
/// the total argument count, so pairs must be listed in ascending index order
/// and at least one pair must be given.
///
/// The expansion:
///  - befriends `torch::nn::AnyModuleHolder`;
///  - overrides `_forward_has_default_args()` to return `true`;
///  - overrides `_forward_num_required_args()` to return the first pair's
///    index;
///  - overrides `_forward_populate_default_args(arguments)`, which asserts
///    that `arguments.size()` lies in [num required, num total] and returns
///    `arguments` extended with the default value of every argument the
///    caller did not supply.
///
/// NOTE: comments inside the macro body must use `/* */`; a `//` comment
/// would swallow the line continuation that follows it.
#define FORWARD_HAS_DEFAULT_ARGS(...)                                    \
  template <typename ModuleType, typename... ArgumentTypes>              \
  friend struct torch::nn::AnyModuleHolder;                              \
  bool _forward_has_default_args() override {                            \
    return true;                                                         \
  }                                                                      \
  unsigned int _forward_num_required_args() override {                   \
    std::vector<std::pair<unsigned int, torch::nn::AnyValue>> args_info{ \
        __VA_ARGS__};                                                    \
    /* An empty pair list would make the dereference below undefined. */ \
    TORCH_INTERNAL_ASSERT(!args_info.empty());                           \
    return std::begin(args_info)->first;                                 \
  }                                                                      \
  std::vector<torch::nn::AnyValue> _forward_populate_default_args(       \
      std::vector<torch::nn::AnyValue>&& arguments) override {           \
    std::vector<std::pair<unsigned int, torch::nn::AnyValue>> args_info{ \
        __VA_ARGS__};                                                    \
    TORCH_INTERNAL_ASSERT(!args_info.empty());                           \
    unsigned int num_all_args = std::rbegin(args_info)->first + 1;       \
    TORCH_INTERNAL_ASSERT(                                               \
        arguments.size() >= _forward_num_required_args() &&              \
        arguments.size() <= num_all_args);                               \
    std::vector<torch::nn::AnyValue> ret = std::move(arguments);         \
    ret.reserve(num_all_args);                                           \
    for (auto& arg_info : args_info) {                                   \
      /* `size() - 1` would wrap around when `ret` is empty. */          \
      if (arg_info.first >= ret.size())                                  \
        ret.emplace_back(std::move(arg_info.second));                    \
    }                                                                    \
    return ret;                                                          \
  }