Program Listing for File moduledict.h
↰ Return to documentation for file (torch/csrc/api/include/torch/nn/modules/container/moduledict.h)
#pragma once
#include <torch/nn/cloneable.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <vector>
namespace torch::nn {
/// A dictionary-style container of submodules keyed by `std::string`.
///
/// Entries are stored in a `torch::OrderedDict`. Every module added through a
/// constructor or `update()` is additionally handed to `register_module()`
/// (for a new key) or `replace_module()` (for a key already in the
/// dictionary). `clear()` and `pop()` only drop entries from the dictionary;
/// they deliberately leave that registration untouched.
class ModuleDictImpl : public Cloneable<ModuleDictImpl> {
 public:
  using Iterator =
      torch::OrderedDict<std::string, std::shared_ptr<Module>>::Iterator;
  using ConstIterator =
      torch::OrderedDict<std::string, std::shared_ptr<Module>>::ConstIterator;

  /// Constructs an empty `ModuleDict`.
  ModuleDictImpl() = default;

  /// Constructs the `ModuleDict` from a list of (key, module) pairs, inserting
  /// them in the order given.
  explicit ModuleDictImpl(
      const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
          modules) {
    update(modules);
  }

  /// Constructs the `ModuleDict` from an `OrderedDict` of modules.
  explicit ModuleDictImpl(
      const torch::OrderedDict<std::string, std::shared_ptr<Module>>& modules) {
    update(modules);
  }

  /// Returns the (key, module) pairs held by the dictionary.
  std::vector<std::pair<std::string, std::shared_ptr<Module>>> items() const {
    return modules_.pairs();
  }

  /// Returns the keys held by the dictionary.
  std::vector<std::string> keys() const {
    return modules_.keys();
  }

  /// Returns the modules held by the dictionary.
  std::vector<std::shared_ptr<Module>> values() const {
    return modules_.values();
  }

  /// Returns an iterator to the first entry of the dictionary.
  Iterator begin() {
    return modules_.begin();
  }

  /// Returns a const iterator to the first entry of the dictionary.
  ConstIterator begin() const {
    return modules_.begin();
  }

  /// Returns an iterator one past the last entry of the dictionary.
  Iterator end() {
    return modules_.end();
  }

  /// Returns a const iterator one past the last entry of the dictionary.
  ConstIterator end() const {
    return modules_.end();
  }

  /// Returns the number of entries currently in the dictionary.
  size_t size() const noexcept {
    return modules_.size();
  }

  /// Returns `true` if the dictionary holds no entries.
  bool empty() const noexcept {
    return modules_.is_empty();
  }

  /// Returns `true` if the dictionary has an entry stored under `key`.
  bool contains(const std::string& key) const noexcept {
    return modules_.contains(key);
  }

  /// Removes every entry from the dictionary.
  void clear() {
    // The registration of the modules is intentionally left in place, to stay
    // consistent with the Python version.
    modules_.clear();
  }

  /// Builds a new `ModuleDictImpl` whose entries are clones of this
  /// dictionary's modules, stored under the same keys and in the same order.
  /// `device` is forwarded unchanged to each module's own `clone()`.
  std::shared_ptr<Module> clone(
      const std::optional<Device>& device = std::nullopt) const override {
    auto clone = std::make_shared<ModuleDictImpl>();
    for (const auto& module : modules_) {
      clone->insert(module.key(), module.value()->clone(device));
    }
    return clone;
  }

  /// Intentionally a no-op: the dictionary has nothing of its own to
  /// (re)initialize.
  void reset() override {}

  /// Writes the literal name `torch::nn::ModuleDict` into `stream`.
  void pretty_print(std::ostream& stream) const override {
    stream << "torch::nn::ModuleDict";
  }

  /// Returns the module stored under `key`. The lookup (including the handling
  /// of an absent key) is delegated to `OrderedDict::operator[]`.
  std::shared_ptr<Module> operator[](const std::string& key) const {
    return modules_[key];
  }

  /// Returns the module stored under `key`, viewed as a `T&`.
  /// `T` must satisfy `torch::detail::is_module` (checked at compile time).
  /// A `TORCH_CHECK` fails if `as<T>()` yields a null result, i.e. the stored
  /// module cannot be viewed as a `T`.
  template <typename T>
  T& at(const std::string& key) {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleDict::at with an nn::Module type");
    auto module = modules_[key]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        key,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Const overload of `at<T>()`: returns the module stored under `key`,
  /// viewed as a `const T&`, with the same compile-time and runtime checks.
  template <typename T>
  const T& at(const std::string& key) const {
    static_assert(
        torch::detail::is_module<T>::value,
        "Can only call ModuleDict::at with an nn::Module type");
    const auto module = modules_[key]->as<T>();
    TORCH_CHECK(
        module,
        "Unable to cast module[",
        key,
        "] to ",
        c10::demangle(typeid(T).name()));
    return *module;
  }

  /// Removes the entry stored under `key` from the dictionary and returns the
  /// module it held.
  std::shared_ptr<Module> pop(const std::string& key) {
    // Grab a copy of the pointer first so the module outlives the erase.
    auto module = modules_[key];
    modules_.erase(key);
    // The registration of the module is intentionally left in place, to stay
    // consistent with the Python version.
    return module;
  }

  /// Inserts every (key, module) pair of `modules`, in order. A key that is
  /// already present has its module replaced.
  void update(
      const std::vector<std::pair<std::string, std::shared_ptr<Module>>>&
          modules) {
    for (const auto& item : modules) {
      insert(item.first, item.second);
    }
  }

  /// Inserts every element of `container`, in iteration order. Elements must
  /// expose `key()` and `value()` accessors (as `OrderedDict` items do). A key
  /// that is already present has its module replaced.
  template <typename Container>
  void update(const Container& container) {
    for (const auto& item : container) {
      insert(item.key(), item.value());
    }
  }

 private:
  /// Backing storage for the dictionary entries.
  torch::OrderedDict<std::string, std::shared_ptr<Module>> modules_;

  /// Stores `module` under `key`. An existing entry is overwritten and passed
  /// to `replace_module()`; a new entry is appended and passed to
  /// `register_module()`.
  ///
  /// NOTE(review): `clear()` and `pop()` keep the module registration, so
  /// re-inserting a key they removed reaches `register_module()` a second
  /// time for the same name -- confirm that this is accepted there.
  void insert(const std::string& key, std::shared_ptr<Module> module) {
    if (contains(key)) {
      modules_[key] = std::move(module);
      replace_module(key, modules_[key]);
    } else {
      modules_.insert(key, std::move(module));
      // `module` has been moved from; use the freshly appended entry instead.
      register_module(key, modules_.back().value());
    }
  }
};
/// Invokes the `TORCH_MODULE` macro for the name `ModuleDict`.
/// NOTE(review): presumably this defines the user-facing `ModuleDict` holder
/// type that wraps `ModuleDictImpl` -- confirm against the macro definition.
TORCH_MODULE(ModuleDict);
} // namespace torch::nn