Rate this Page

Program Listing for File any_module_holder.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/any_module_holder.h)

#pragma once

#include <torch/csrc/utils/variadic.h>
#include <torch/nn/modules/container/any_value.h>

namespace torch::nn {

class Module;

// ~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModulePlaceholder ~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Abstract, type-erased interface to a stored module. It extends
/// `AnyValue::Placeholder` with the module-specific operations below, each of
/// which is implemented by the `AnyModuleHolder` template in this file.
struct AnyModulePlaceholder : public AnyValue::Placeholder {
  // Re-use the constructors of `AnyValue::Placeholder` unchanged.
  using AnyValue::Placeholder::Placeholder;

  /// Runs the stored module's `forward()` on `arguments`, which arrive
  /// type-erased as `AnyValue`s, and returns the result as an `AnyValue`.
  /// The vector is taken by rvalue reference, so the caller gives it up.
  virtual AnyValue forward(std::vector<AnyValue>&& arguments) = 0;

  /// Returns the stored module as a `std::shared_ptr<Module>`.
  virtual std::shared_ptr<Module> ptr() = 0;

  /// Returns a newly allocated placeholder that is a copy of this one.
  virtual std::unique_ptr<AnyModulePlaceholder> copy() const = 0;

  /// Returns a newly allocated placeholder wrapping a clone of the stored
  /// module. `device` is optional and is passed on to the module's own
  /// `clone()` by the implementation in this file.
  virtual std::unique_ptr<AnyModulePlaceholder> clone_module(
      std::optional<Device> device) const = 0;
};

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ AnyModuleHolder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/// Concrete `AnyModulePlaceholder` that owns a `std::shared_ptr<ModuleType>`.
/// `ArgumentTypes...` are the parameter types used to unpack the type-erased
/// argument vector before calling `ModuleType::forward()`.
template <typename ModuleType, typename... ArgumentTypes>
struct AnyModuleHolder : public AnyModulePlaceholder {
  /// Functor passed to `torch::unpack` as the argument source: given an
  /// index, it yields the stored argument at that position as a
  /// `std::decay_t<T>` rvalue, or raises if the stored value has another type.
  struct CheckedGetter {
    template <typename T>
    std::decay_t<T>&& operator()(size_t index) {
      AT_ASSERT(index < arguments_.size());
      auto& value = arguments_[index];
      // `try_get` hands back a pointer that is null on a type mismatch.
      if (auto* maybe_value = value.template try_get<std::decay_t<T>>()) {
        // The value is moved out of the vector, so the slot is left in a
        // moved-from state afterwards.
        return std::move(*maybe_value);
      }
      // Unconditional failure: control never leaves this function normally
      // past this point, which is why there is no trailing `return`.
      TORCH_CHECK(
          false,
          "Expected argument #",
          index,
          " to be of type ",
          c10::demangle(typeid(T).name()),
          ", but received value of type ",
          c10::demangle(value.type_info().name()));
    }
    /// The argument vector being unpacked; referenced, not owned.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    std::vector<AnyValue>& arguments_;
  };

  /// Functor passed to `torch::unpack` as the call target: forwards the
  /// unpacked arguments to the module's `forward()` and wraps the result in
  /// an `AnyValue`.
  struct InvokeForward {
    template <typename... Ts>
    AnyValue operator()(Ts&&... ts) {
      return AnyValue(module_->forward(std::forward<Ts>(ts)...));
    }
    /// The holder's module pointer; referenced, not owned.
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
    std::shared_ptr<ModuleType>& module_;
  };

  /// Takes ownership of `module_` and records `typeid(ModuleType)` in the
  /// base class.
  explicit AnyModuleHolder(std::shared_ptr<ModuleType>&& module_)
      : AnyModulePlaceholder(typeid(ModuleType)), module(std::move(module_)) {}

  /// Validates the number of `arguments`, fills in default arguments when the
  /// module declares them, then calls the module's `forward()` with each
  /// argument converted to the matching entry of `ArgumentTypes...`.
  ///
  /// Raises via `TORCH_CHECK` when the argument count is not acceptable, or
  /// (from `CheckedGetter`) when an argument holds a value of the wrong type.
  AnyValue forward(std::vector<AnyValue>&& arguments) override {
    if (module->_forward_has_default_args()) {
      // With default arguments, any count between the number of required
      // arguments and the full parameter count is accepted.
      TORCH_CHECK(
          arguments.size() >= module->_forward_num_required_args() &&
              arguments.size() <= sizeof...(ArgumentTypes),
          c10::demangle(type_info.name()),
          "'s forward() method expects at least ",
          module->_forward_num_required_args(),
          " argument(s) and at most ",
          sizeof...(ArgumentTypes),
          " argument(s), but received ",
          arguments.size(),
          ".");
      // Replace the argument list with the one returned by the module.
      arguments = std::move(
          module->_forward_populate_default_args(std::move(arguments)));
    } else if (arguments.size() != sizeof...(ArgumentTypes)) {
      // Without default arguments the count must match exactly. The hint text
      // needs a demangled type name and several string concatenations, so it
      // is only assembled here, once the call is already known to be failing;
      // a call with the right number of arguments does no string work.
      std::string use_default_args_macro_prompt;
      if (arguments.size() < sizeof...(ArgumentTypes)) {
        // Too few arguments may mean the user forgot to declare defaults.
        use_default_args_macro_prompt = " If " +
            c10::demangle(type_info.name()) +
            "'s forward() method has default arguments, " +
            "please make sure the forward() method is declared with a corresponding `FORWARD_HAS_DEFAULT_ARGS` macro.";
      }
      TORCH_CHECK(
          false,
          c10::demangle(type_info.name()),
          "'s forward() method expects ",
          sizeof...(ArgumentTypes),
          " argument(s), but received ",
          arguments.size(),
          ".",
          use_default_args_macro_prompt);
    }

    // FYI: During invocation of a module's `forward()` method, the values live
    // in the `arguments` vector inside this function.
    return torch::unpack<AnyValue, ArgumentTypes...>(
        InvokeForward{module}, CheckedGetter{arguments});
  }

  /// Returns the held module, converted to `std::shared_ptr<Module>`.
  std::shared_ptr<Module> ptr() override {
    return module;
  }

  /// Returns a copy-constructed holder. Only the `shared_ptr` is copied, so
  /// the new holder points at the same module object as this one.
  std::unique_ptr<AnyModulePlaceholder> copy() const override {
    return std::make_unique<AnyModuleHolder>(*this);
  }

  /// Returns a holder around `module->clone(device)`, downcast back to
  /// `ModuleType`.
  // NOTE(review): if the dynamic cast fails, the new holder stores a null
  // module; presumably `clone()` always yields a `ModuleType` -- confirm.
  std::unique_ptr<AnyModulePlaceholder> clone_module(
      std::optional<Device> device) const override {
    return std::make_unique<AnyModuleHolder>(
        std::dynamic_pointer_cast<ModuleType>(module->clone(device)));
  }

  /// The module held by this holder.
  std::shared_ptr<ModuleType> module;
};

} // namespace torch::nn