Class Module#
Defined in File module.h
Inheritance Relationships#
Base Type#
public std::enable_shared_from_this< Module >
Derived Types#
public torch::nn::Cloneable< T > (Template Class Cloneable), instantiated for each of the following T:

SoftshrinkImpl, PReLUImpl, LogSoftmaxImpl, L1LossImpl, SequentialImpl, HardshrinkImpl, GLUImpl, RReLUImpl, ParameterDictImpl, IdentityImpl, FoldImpl, EmbeddingBagImpl, BilinearImpl, TripletMarginWithDistanceLossImpl, SoftminImpl, SmoothL1LossImpl, MultiLabelMarginLossImpl, LeakyReLUImpl, FunctionalImpl, ELUImpl, TanhshrinkImpl, PairwiseDistanceImpl, LogSigmoidImpl, HardtanhImpl, FractionalMaxPool2dImpl, FlattenImpl, CrossMapLRN2dImpl, TransformerEncoderLayerImpl, ThresholdImpl, SoftsignImpl, MultiMarginLossImpl, FractionalMaxPool3dImpl, CTCLossImpl, UnfoldImpl, SiLUImpl, ParameterListImpl, MultiheadAttentionImpl, CELUImpl, UpsampleImpl, TransformerImpl, SELUImpl, PixelUnshuffleImpl, LinearImpl, HingeEmbeddingLossImpl, EmbeddingImpl, MultiLabelSoftMarginLossImpl, CrossEntropyLossImpl, TripletMarginLossImpl, TransformerDecoderLayerImpl, SoftMarginLossImpl, LocalResponseNormImpl, BCELossImpl, LayerNormImpl, AdaptiveLogSoftmaxWithLossImpl, ReLUImpl, ModuleListImpl, HuberLossImpl, GELUImpl, SoftmaxImpl, Softmax2dImpl, SoftplusImpl, SigmoidImpl, PoissonNLLLossImpl, ModuleDictImpl, MishImpl, UnflattenImpl, ReLU6Impl, MSELossImpl, CosineSimilarityImpl, CosineEmbeddingLossImpl, TransformerDecoderImpl, TanhImpl, NLLLossImpl, MarginRankingLossImpl, BCEWithLogitsLossImpl, TransformerEncoderImpl, PixelShuffleImpl, KLDivLossImpl, GroupNormImpl, and the generic Derived.
Class Documentation#
-
class Module : public std::enable_shared_from_this<Module>#
The base class for all modules in PyTorch.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules ("submodules"), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module's constructor.

A distinction is made between three kinds of persistent data that may be associated with a Module:

1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.

The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type of a Module's registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.

Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module's constructor.

Note
The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.

Subclassed by torch::nn::Cloneable< T > for every instantiation listed under Derived Types above, and by the generic torch::nn::Cloneable< Derived >.
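To make the registration story concrete, here is a minimal sketch of a custom module that registers all three kinds of state (the names NetImpl, Net, fc, scale and running_mean are illustrative, not part of this API):

#include <torch/torch.h>

struct NetImpl : torch::nn::Module {
  NetImpl() {
    // Register a submodule, a parameter and a buffer, typically in the constructor.
    fc = register_module("fc", torch::nn::Linear(4, 3));
    scale = register_parameter("scale", torch::ones({3}));
    running_mean = register_buffer("running_mean", torch::zeros({3}));
  }

  torch::Tensor forward(torch::Tensor x) {
    return fc->forward(x) * scale;
  }

  torch::nn::Linear fc{nullptr};
  torch::Tensor scale, running_mean;
};
TORCH_MODULE(Net);  // generates the Net module holder around NetImpl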
Public Types

-
using ModuleApplyFunction = std::function<void(Module&)>#
-
using ConstModuleApplyFunction = std::function<void(const Module&)>#
-
using NamedModuleApplyFunction = std::function<void(const std::string&, Module&)>#
-
using ConstNamedModuleApplyFunction = std::function<void(const std::string&, const Module&)>#
-
using ModulePointerApplyFunction = std::function<void(const std::shared_ptr<Module>&)>#
-
using NamedModulePointerApplyFunction = std::function<void(const std::string&, const std::shared_ptr<Module>&)>#
Public Functions
-
Module()#
Constructs the module without immediate knowledge of the submodule’s name.
The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
-
virtual ~Module() = default#
-
const std::string &name() const noexcept#
Returns the name of the Module.

A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class' constructor.
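For example (the exact string comes from RTTI and may vary across compilers):

torch::nn::Linear linear(3, 4);
std::cout << linear->name() << '\n';  // typically "torch::nn::LinearImpl"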
-
virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const#
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
Attention
Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class’ API solely for an easier-to-use polymorphic interface.
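A minimal sketch of a cloneable module (MyCloneableImpl and weight are illustrative names):

struct MyCloneableImpl : torch::nn::Cloneable<MyCloneableImpl> {
  MyCloneableImpl() { reset(); }
  void reset() override {
    // clone() copy-constructs the module, clears its registrations and calls
    // reset() on the copy, so all registration must happen here.
    weight = register_parameter("weight", torch::randn({3, 3}));
  }
  torch::Tensor weight;
};

auto original = std::make_shared<MyCloneableImpl>();
auto copy = original->clone();  // pass e.g. torch::kCUDA to place the copy on a device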
-
void apply(const ModuleApplyFunction &function)#
Applies the function to the Module and recursively to every submodule.

The function must accept a Module&.
-
void apply(const ConstModuleApplyFunction &function) const#
Applies the function to the Module and recursively to every submodule.

The function must accept a const Module&.
-
void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix = std::string())#
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix = std::string()) const#
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
-
void apply(const ModulePointerApplyFunction &function) const#
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::shared_ptr<Module>&.
-
void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix = std::string()) const#
Applies the function to the Module and recursively to every submodule.

The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
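To illustrate the named variants, a short sketch using the prefix "net" (the lambda and prefix are illustrative):

torch::nn::Sequential seq(torch::nn::Linear(2, 3), torch::nn::ReLU());
seq->apply([](const std::string& key, torch::nn::Module& module) {
  std::cout << key << ": " << module.name() << '\n';
}, "net");
// Visits the keys "net", "net.0" and "net.1".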
-
std::vector<Tensor> parameters(bool recurse = true) const#
Returns the parameters of this Module and, if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const#
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
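For instance, a minimal sketch iterating a Linear module's parameters:

torch::nn::Linear linear(3, 4);
for (const auto& pair : linear->named_parameters()) {
  std::cout << pair.key() << ": " << pair.value().sizes() << '\n';
}
// Prints "weight: [4, 3]" followed by "bias: [4]".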
-
std::vector<Tensor> buffers(bool recurse = true) const#
Returns the buffers of this Module and, if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const#
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
-
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const#
Returns the submodules of this Module (the entire submodule hierarchy) and, if include_self is true, also inserts a shared_ptr to this module in the first position.

Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
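A short sketch of the shared_ptr requirement (NetImpl as in the earlier sketch):

auto net = std::make_shared<NetImpl>();
auto all = net->modules();  // OK: *net is owned by a shared_ptr

NetImpl local;
auto subs = local.modules(/*include_self=*/false);  // OK: self is excluded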
-
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix = std::string(), bool include_self = true) const#
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.

If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).

Warning
Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
std::vector<std::shared_ptr<Module>> children() const#
Returns the direct submodules of this Module.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const#
Returns an OrderedDict of the direct submodules of this Module and their keys.
-
virtual void train(bool on = true)#
Enables “training” mode.
-
void eval()#
Calls train(false) to enable “eval” mode.
Do not override this method; override train() instead.
-
virtual bool is_training() const noexcept#
True if the module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
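For example:

torch::nn::Dropout dropout(0.5);
dropout->eval();   // is_training() is now false; forward() passes input through unchanged
dropout->train();  // restores the stochastic training-mode behavior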
-
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)#
Recursively casts all parameters to the given dtype and device.

If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Dtype dtype, bool non_blocking = false)#
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Device device, bool non_blocking = false)#
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
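For example:

torch::nn::Linear linear(3, 4);
linear->to(torch::kFloat64);   // cast all parameters to double
if (torch::cuda::is_available()) {
  linear->to(torch::kCUDA);    // move all parameters to the default CUDA device
}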
-
virtual void zero_grad(bool set_to_none = true)#
Recursively zeros out the grad value of each registered parameter.
-
template<typename ModuleType>
ModuleType::ContainedType *as() noexcept#

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply():

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module->apply(initialize_weights);
-
template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept#

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply().
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept#

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply():

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept#

Attempts to cast this Module to the given ModuleType.

This method is useful when calling apply():

void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}

MyModule module;
module.apply(initialize_weights);
-
virtual void save(serialize::OutputArchive &archive) const#
Serializes the Module into the given OutputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.
-
virtual void load(serialize::InputArchive &archive)#
Deserializes the Module from the given InputArchive.

If the Module contains unserializable submodules (e.g. nn::Functional), we don't check the existence of those submodules in the InputArchive when deserializing.
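In practice these archives are usually driven through the torch::save and torch::load convenience functions, e.g.:

torch::nn::Linear linear(3, 4);
torch::save(linear, "linear.pt");    // writes via Module::save()
torch::nn::Linear restored(3, 4);
torch::load(restored, "linear.pt");  // reads via Module::load()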
-
virtual void pretty_print(std::ostream &stream) const#
Streams a pretty representation of the Module into the given stream.

By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module's submodules.

Override this method to customize how the Module is printed.
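A minimal sketch of an override (MyModuleImpl and in_features_ are illustrative names, not part of this API):

void MyModuleImpl::pretty_print(std::ostream& stream) const {
  stream << "MyModule(in_features=" << in_features_ << ")";
}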
-
Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)#

Registers a parameter with this Module.

A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().

Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.

MyModule::MyModule() {
  weight_ = register_parameter("weight", torch::randn({A, B}));
}
-
Tensor &register_buffer(std::string name, Tensor tensor)#

Registers a buffer with this Module.

A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().

MyModule::MyModule() {
  mean_ = register_buffer("mean", torch::empty({num_features_}));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, std::shared_ptr<ModuleType> module)#

Registers a submodule with this Module.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> register_module(std::string name, ModuleHolder<ModuleType> module_holder)#

Registers a submodule with this Module.

This method deals with ModuleHolders.

Registering a module makes it available to methods such as modules(), clone() or to().

MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
-
template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, std::shared_ptr<ModuleType> module)#

Replaces a registered submodule with this Module.

This method takes care of the registration; if you use submodule member variables, you should assign the result back to the member, e.g.

module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4));

It only works when a module of the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.
-
template<typename ModuleType>
std::shared_ptr<ModuleType> replace_module(const std::string &name, ModuleHolder<ModuleType> module_holder)#

Replaces a registered submodule with this Module.

This method deals with ModuleHolders.

This method takes care of the registration; if you use submodule member variables, you should assign the result back to the member, e.g.

module->submodule_ = module->replace_module("linear", linear_holder);

It only works when a module of the given name is already registered.

This is useful for replacing a module after initialization, e.g. for finetuning.
Protected Functions
-
inline virtual bool _forward_has_default_args()#
The following three functions allow a module with default arguments in its forward method to be used in a Sequential module.
You should NEVER override these functions manually. Instead, you should use the FORWARD_HAS_DEFAULT_ARGS macro.
-
inline virtual unsigned int _forward_num_required_args()#
-
inline virtual std::vector<AnyValue> _forward_populate_default_args(std::vector<AnyValue> &&arguments)#
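A sketch of the intended usage of the macro, assuming a module whose forward takes one defaulted argument (MImpl, scale and the default value 2 are illustrative):

struct MImpl : torch::nn::Module {
  torch::Tensor forward(torch::Tensor x, int scale = 2) {
    return x * scale;
  }

 protected:
  // Declares that argument 1 ("scale") has the default value 2, so the module
  // can be invoked through Sequential with only the first argument supplied.
  FORWARD_HAS_DEFAULT_ARGS({1, torch::nn::AnyValue(2)})
};
TORCH_MODULE(M);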
Protected Attributes
-
OrderedDict<std::string, Tensor> parameters_#
The registered parameters of this Module.

Kept protected (rather than private) so that parameters_ can be accessed from ParameterDict and ParameterList.