Rate this Page

Program Listing for File moduledict.h

Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/moduledict.h)

#pragma once

#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <vector>

namespace torch::nn {

/// A string-keyed container of submodules backed by a `torch::OrderedDict`.
///
/// Every module added to the dict is stored in `modules_` and is also
/// registered on this module through `register_module` / `replace_module`
/// (see `insert`). Removing entries via `clear()` or `pop()` only touches
/// `modules_`; the registration is deliberately left in place.
class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
 public:
  using Iterator =
      torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator;
  using ConstIterator =
      torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;

  /// Constructs an empty dict.
  ModuleDictImpl() = default;

  /// Constructs the dict from a vector of `(key, module)` pairs by forwarding
  /// them to `update`.
  explicit ModuleDictImpl(
      const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
          modules) {
    update(modules);
  }

  /// Constructs the dict from an existing `OrderedDict` by forwarding its
  /// items to `update`.
  explicit ModuleDictImpl(
      const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) {
    update(modules);
  }

  /// Returns the `(key, module)` pairs currently held in the dict.
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const {
    return modules_.pairs();
  }

  /// Returns the keys currently held in the dict.
  std::vector<std::string> keys() const {
    return modules_.keys();
  }

  /// Returns the modules currently held in the dict.
  std::vector<std::shared_ptr<Module>> values() const {
    return modules_.values();
  }

  /// Returns an iterator to the first item of the underlying `OrderedDict`.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first item of the underlying
  /// `OrderedDict`.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns the past-the-end iterator of the underlying `OrderedDict`.
  Iterator end() {
    return modules_.end();
  }

  /// Returns the past-the-end const iterator of the underlying `OrderedDict`.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns the number of items currently held in the dict.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns `true` when the dict holds no items.
  bool empty() const noexcept {
    return modules_.is_empty();
  }

  /// Returns `true` when `key` is present in the dict.
  bool contains(const std::string& key) const noexcept {
    return modules_.contains(key);
  }

  /// Removes every item from the dict.
  void clear() {
    // The modules' registration is intentionally not removed, to stay
    // consistent with the Python version.
    modules_.clear();
  }

  /// Creates a new `ModuleDictImpl` that holds, under the same keys, a
  /// `clone(device)` of every module stored in this dict.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    // Named `copy` rather than `clone` so the local does not shadow this
    // member function.
    auto copy = std::make_shared<ModuleDictImpl>();
    for (const auto& entry : modules_) {
      copy->insert(entry.key(), entry.value()->clone(device));
    }
    return copy;
  }

  /// Intentionally a no-op for this container.
  void reset() override {}

  /// Writes the fixed name `torch::nn::ModuleDict` to `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleDict";
  }

  /// Returns the module stored under `key`.
  /// Behaviour for a missing key is whatever `OrderedDict::operator[]` does
  /// (presumably an error is raised -- confirm against `OrderedDict`).
  std::shared_ptr<Module> operator[](const std::string& key) const {
    return modules_[key];
  }

  /// Returns a reference to the module stored under `key`, cast to `T`.
  /// `T` must satisfy `torch::detail::is_module`; a `TORCH_CHECK` failure is
  /// raised when the stored module cannot be cast to `T`.
  template <typename T>
  T& at(const std::string& key) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleDict::at with an nn::Module type");
    auto module = modules_[key]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        key,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Const overload of `at`: returns a const reference to the module stored
  /// under `key`, cast to `T`, with the same checks as the non-const version.
  template <typename T>
  const T& at(const std::string& key) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleDict::at with an nn::Module type");
    const auto module = modules_[key]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        key,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Removes the item stored under `key` from the dict and returns its module.
  std::shared_ptr<Module> pop(const std::string& key) {
    // Take a copy of the shared_ptr before the entry is erased.
    auto module = modules_[key];
    modules_.erase(key);
    // The module's registration is intentionally not removed, to stay
    // consistent with the Python version.
    return module;
  }

  /// Inserts every `(key, module)` pair of `modules` into the dict; a key that
  /// is already present has its module replaced (see `insert`).
  void update(
      const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
          modules) {
    for (const auto& item : modules) {
      insert(item.first, item.second);
    }
  }

  /// Inserts every item of `container` into the dict. Items must expose
  /// `key()` and `value()` accessors; a key that is already present has its
  /// module replaced (see `insert`).
  template <typename Container>
  void update(const Container& container) {
    for (const auto& item : container) {
      insert(item.key(), item.value());
    }
  }

 private:
  /// Storage for the `(key, module)` items of this dict.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_;

  /// Stores `module` under `key`. An existing key is overwritten and passed to
  /// `replace_module`; a new key is appended and passed to `register_module`.
  void insert(const std::string& key, std::shared_ptr<Module> module) {
    if (contains(key)) {
      modules_[key] = std::move(module);
      replace_module(key, modules_[key]);
    } else {
      // NOTE(review): clear()/pop() drop the key from `modules_` but keep the
      // registration, so this branch can be reached for a name that was
      // registered earlier -- confirm `register_module` tolerates that.
      modules_.insert(key, std::move(module));
      register_module(key, modules_.back().value());
    }
  }
};

// Invokes the TORCH_MODULE macro (defined outside this file) for `ModuleDict`;
// presumably this generates the `ModuleDict` holder type wrapping
// `ModuleDictImpl` -- confirm against the macro's definition.
TORCH_MODULE(ModuleDict);

} // namespace torch::nn