Neural Network Modules (torch::nn)
The torch::nn namespace provides neural network building blocks that mirror
Python’s torch.nn module. It uses a PIMPL (Pointer to Implementation) pattern
where user-facing classes like Conv2d wrap internal Conv2dImpl classes.
When to use torch::nn:
Building neural network models in C++
Creating custom layers and modules
Porting Python models to C++ for production inference
Training models entirely in C++
Basic usage:
#include <torch/torch.h>

// Define a simple model
struct Net : torch::nn::Module {
  torch::nn::Conv2d conv1{nullptr};
  torch::nn::Linear fc1{nullptr};

  Net() {
    conv1 = register_module("conv1", torch::nn::Conv2d(
        torch::nn::Conv2dOptions(1, 32, 3).stride(1).padding(1)));
    fc1 = register_module("fc1", torch::nn::Linear(32 * 28 * 28, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(conv1->forward(x));
    x = x.view({-1, 32 * 28 * 28});
    return fc1->forward(x);
  }
};

// Create and use the model
auto model = std::make_shared<Net>();
auto input = torch::randn({1, 1, 28, 28});
auto output = model->forward(input);
Header Files
torch/nn.h - Main neural network header (includes all modules)
torch/nn/module.h - Base Module class
torch/nn/modules.h - All module implementations
torch/nn/options.h - Options structs for modules
torch/nn/functional.h - Functional API
Module Base Class
All neural network modules inherit from torch::nn::Module, which provides
parameter management, serialization, device/dtype conversion, and hooks.
-
class Module : public std::enable_shared_from_this<Module>
The base class for all modules in PyTorch.
Note: The design and implementation of this class is largely based on the Python API. You may want to consult the Python documentation for torch.nn.Module for further clarification on certain methods or behavior.
A Module is an abstraction over the implementation of some function or algorithm, possibly associated with some persistent data. A Module may contain further Modules (“submodules”), each with their own implementation, persistent data and further submodules. Modules can thus be said to form a recursive tree structure. A Module is registered as a submodule to another Module by calling register_module(), typically from within a parent module’s constructor.
A distinction is made between three kinds of persistent data that may be associated with a Module:
1. Parameters: tensors that record gradients, typically weights updated during the backward step (e.g. the weight of a Linear module),
2. Buffers: tensors that do not record gradients, typically updated during the forward step, such as running statistics (e.g. mean and variance in the BatchNorm module),
3. Any additional state, not necessarily tensors, required for the implementation or configuration of a Module.
The first two kinds of state are special in that they may be registered with the Module system to allow convenient access and batch configuration. For example, registered parameters in any Module may be iterated over via the parameters() accessor. Further, changing the data type or device of a Module’s registered parameters can be done conveniently via Module::to(), e.g. module->to(torch::kCUDA) to move all parameters to GPU memory. Lastly, registered parameters and buffers are handled specially during a clone() operation, which performs a deep copy of a cloneable Module hierarchy.
Parameters are registered with a Module via register_parameter. Buffers are registered separately via register_buffer. These methods are part of the public API of Module and are typically invoked from within a concrete Module’s constructor.
Subclassed by torch::nn::Cloneable< SoftshrinkImpl >, torch::nn::Cloneable< PReLUImpl >, torch::nn::Cloneable< LogSoftmaxImpl >, torch::nn::Cloneable< L1LossImpl >, torch::nn::Cloneable< SequentialImpl >, torch::nn::Cloneable< HardshrinkImpl >, torch::nn::Cloneable< GLUImpl >, torch::nn::Cloneable< RReLUImpl >, torch::nn::Cloneable< ParameterDictImpl >, torch::nn::Cloneable< IdentityImpl >, torch::nn::Cloneable< FoldImpl >, torch::nn::Cloneable< EmbeddingBagImpl >, torch::nn::Cloneable< BilinearImpl >, torch::nn::Cloneable< TripletMarginWithDistanceLossImpl >, torch::nn::Cloneable< SoftminImpl >, torch::nn::Cloneable< SmoothL1LossImpl >, torch::nn::Cloneable< MultiLabelMarginLossImpl >, torch::nn::Cloneable< LeakyReLUImpl >, torch::nn::Cloneable< FunctionalImpl >, torch::nn::Cloneable< ELUImpl >, torch::nn::Cloneable< TanhshrinkImpl >, torch::nn::Cloneable< PairwiseDistanceImpl >, torch::nn::Cloneable< LogSigmoidImpl >, torch::nn::Cloneable< HardtanhImpl >, torch::nn::Cloneable< FractionalMaxPool2dImpl >, torch::nn::Cloneable< FlattenImpl >, torch::nn::Cloneable< CrossMapLRN2dImpl >, torch::nn::Cloneable< TransformerEncoderLayerImpl >, torch::nn::Cloneable< ThresholdImpl >, torch::nn::Cloneable< SoftsignImpl >, torch::nn::Cloneable< MultiMarginLossImpl >, torch::nn::Cloneable< FractionalMaxPool3dImpl >, torch::nn::Cloneable< CTCLossImpl >, torch::nn::Cloneable< UnfoldImpl >, torch::nn::Cloneable< SiLUImpl >, torch::nn::Cloneable< ParameterListImpl >, torch::nn::Cloneable< MultiheadAttentionImpl >, torch::nn::Cloneable< CELUImpl >, torch::nn::Cloneable< UpsampleImpl >, torch::nn::Cloneable< TransformerImpl >, torch::nn::Cloneable< SELUImpl >, torch::nn::Cloneable< PixelUnshuffleImpl >, torch::nn::Cloneable< LinearImpl >, torch::nn::Cloneable< HingeEmbeddingLossImpl >,
torch::nn::Cloneable< EmbeddingImpl >, torch::nn::Cloneable< MultiLabelSoftMarginLossImpl >, torch::nn::Cloneable< CrossEntropyLossImpl >, torch::nn::Cloneable< TripletMarginLossImpl >, torch::nn::Cloneable< TransformerDecoderLayerImpl >, torch::nn::Cloneable< SoftMarginLossImpl >, torch::nn::Cloneable< LocalResponseNormImpl >, torch::nn::Cloneable< BCELossImpl >, torch::nn::Cloneable< LayerNormImpl >, torch::nn::Cloneable< AdaptiveLogSoftmaxWithLossImpl >, torch::nn::Cloneable< ReLUImpl >, torch::nn::Cloneable< ModuleListImpl >, torch::nn::Cloneable< HuberLossImpl >, torch::nn::Cloneable< GELUImpl >, torch::nn::Cloneable< SoftmaxImpl >, torch::nn::Cloneable< Softmax2dImpl >, torch::nn::Cloneable< SoftplusImpl >, torch::nn::Cloneable< SigmoidImpl >, torch::nn::Cloneable< PoissonNLLLossImpl >, torch::nn::Cloneable< ModuleDictImpl >, torch::nn::Cloneable< MishImpl >, torch::nn::Cloneable< UnflattenImpl >, torch::nn::Cloneable< ReLU6Impl >, torch::nn::Cloneable< MSELossImpl >, torch::nn::Cloneable< CosineSimilarityImpl >, torch::nn::Cloneable< CosineEmbeddingLossImpl >, torch::nn::Cloneable< TransformerDecoderImpl >, torch::nn::Cloneable< TanhImpl >, torch::nn::Cloneable< NLLLossImpl >, torch::nn::Cloneable< MarginRankingLossImpl >, torch::nn::Cloneable< BCEWithLogitsLossImpl >, torch::nn::Cloneable< TransformerEncoderImpl >, torch::nn::Cloneable< PixelShuffleImpl >, torch::nn::Cloneable< KLDivLossImpl >, torch::nn::Cloneable< GroupNormImpl >, torch::nn::Cloneable< Derived >
Public Types
Public Functions
-
Module()
Constructs the module without immediate knowledge of the submodule’s name.
The name of the submodule is inferred via RTTI (if possible) the first time .name() is invoked.
-
virtual ~Module() = default
-
const std::string &name() const noexcept
Returns the name of the Module.
A Module has an associated name, which is a string representation of the kind of concrete Module it represents, such as "Linear" for the Linear module. Under most circumstances, this name is automatically inferred via runtime type information (RTTI). In the unusual circumstance that you have this feature disabled, you may want to manually name your Modules by passing the string name to the Module base class’ constructor.
-
virtual std::shared_ptr<Module> clone(const std::optional<Device> &device = std::nullopt) const
Performs a recursive deep copy of the module and all its registered parameters, buffers and submodules.
Optionally, this method sets the current device to the one supplied before cloning. If no device is given, each parameter and buffer will be moved to the device of its source.
Attention: Attempting to call the clone() method inherited from the base Module class (the one documented here) will fail. To inherit an actual implementation of clone(), you must subclass Cloneable. Cloneable is templatized on the concrete module type, and can thus properly copy a Module. This method is provided on the base class’ API solely for an easier-to-use polymorphic interface.
-
void apply(const ModuleApplyFunction &function)
Applies the function to the Module and recursively to every submodule.
The function must accept a Module&.
MyModule module;
module->apply([](nn::Module& module) {
  std::cout << module.name() << std::endl;
});
-
void apply(const ConstModuleApplyFunction &function) const
Applies the function to the Module and recursively to every submodule.
The function must accept a const Module&.
MyModule module;
module->apply([](const nn::Module& module) {
  std::cout << module.name() << std::endl;
});
-
void apply(const NamedModuleApplyFunction &function, const std::string &name_prefix = std::string())
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
MyModule module;
module->apply([](const std::string& key, nn::Module& module) {
  std::cout << key << ": " << module.name() << std::endl;
});
-
void apply(const ConstNamedModuleApplyFunction &function, const std::string &name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a const Module&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
MyModule module;
module->apply([](const std::string& key, const nn::Module& module) {
  std::cout << key << ": " << module.name() << std::endl;
});
-
void apply(const ModulePointerApplyFunction &function) const
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::shared_ptr<Module>&.
MyModule module;
module->apply([](const std::shared_ptr<nn::Module>& module) {
  std::cout << module->name() << std::endl;
});
-
void apply(const NamedModulePointerApplyFunction &function, const std::string &name_prefix = std::string()) const
Applies the function to the Module and recursively to every submodule.
The function must accept a const std::string& for the key of the module, and a const std::shared_ptr<Module>&. The key of the module itself is the empty string. If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
MyModule module;
module->apply([](const std::string& key,
                 const std::shared_ptr<nn::Module>& module) {
  std::cout << key << ": " << module->name() << std::endl;
});
-
std::vector<Tensor> parameters(bool recurse = true) const
Returns the parameters of this Module and if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_parameters(bool recurse = true) const
Returns an OrderedDict with the parameters of this Module along with their keys, and if recurse is true also recursively of every submodule.
-
std::vector<Tensor> buffers(bool recurse = true) const
Returns the buffers of this Module and if recurse is true, also recursively of every submodule.
-
OrderedDict<std::string, Tensor> named_buffers(bool recurse = true) const
Returns an OrderedDict with the buffers of this Module along with their keys, and if recurse is true also recursively of every submodule.
-
std::vector<std::shared_ptr<Module>> modules(bool include_self = true) const
Returns the submodules of this Module (the entire submodule hierarchy) and if include_self is true, also inserts a shared_ptr to this module in the first position.
Warning: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_modules(const std::string &name_prefix = std::string(), bool include_self = true) const
Returns an OrderedDict of the submodules of this Module (the entire submodule hierarchy) and their keys, and if include_self is true, also inserts a shared_ptr to this module in the first position.
If name_prefix is given, it is prepended to every key as <name_prefix>.<key> (and just name_prefix for the module itself).
Warning: Only pass include_self as true if this Module is stored in a shared_ptr! Otherwise an exception will be thrown. You may still call this method with include_self set to false if your Module is not stored in a shared_ptr.
-
std::vector<std::shared_ptr<Module>> children() const
Returns the direct submodules of this Module.
-
OrderedDict<std::string, std::shared_ptr<Module>> named_children() const
Returns an OrderedDict of the direct submodules of this Module and their keys.
-
virtual void train(bool on = true)
Enables “training” mode.
-
void eval()
Calls train(false) to enable “eval” mode.
Do not override this method; override train() instead.
-
virtual bool is_training() const noexcept
True if the module is in training mode.
Every Module has a boolean associated with it that determines whether the Module is currently in training mode (set via .train()) or in evaluation (inference) mode (set via .eval()). This property is exposed via is_training(), and may be used by the implementation of a concrete module to modify its runtime behavior. See the BatchNorm or Dropout modules for examples of Modules that use different code paths depending on this property.
-
virtual void to(torch::Device device, torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype and device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Dtype dtype, bool non_blocking = false)
Recursively casts all parameters to the given dtype.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void to(torch::Device device, bool non_blocking = false)
Recursively moves all parameters to the given device.
If non_blocking is true and the source is in pinned memory and destination is on the GPU or vice versa, the copy is performed asynchronously with respect to the host. Otherwise, the argument has no effect.
-
virtual void zero_grad(bool set_to_none = true)
Recursively zeros out the grad value of each registered parameter.
-
template<typename ModuleType>
ModuleType::ContainedType *as() noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module->apply(initialize_weights);
-
template<typename ModuleType>
const ModuleType::ContainedType *as() const noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module->apply(initialize_weights);
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
ModuleType *as() noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module.apply(initialize_weights);
-
template<typename ModuleType, typename = torch::detail::disable_if_module_holder_t<ModuleType>>
const ModuleType *as() const noexcept
Attempts to cast this Module to the given ModuleType.
This method is useful when calling apply().
void initialize_weights(nn::Module& module) {
  torch::NoGradGuard no_grad;
  if (auto* linear = module.as<nn::Linear>()) {
    linear->weight.normal_(0.0, 0.02);
  }
}
MyModule module;
module.apply(initialize_weights);
-
virtual void save(serialize::OutputArchive &archive) const
Serializes the Module into the given OutputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), those submodules are skipped when serializing.
-
virtual void load(serialize::InputArchive &archive)
Deserializes the Module from the given InputArchive.
If the Module contains unserializable submodules (e.g. nn::Functional), the existence of those submodules in the InputArchive is not checked when deserializing.
-
virtual void pretty_print(std::ostream &stream) const
Streams a pretty representation of the Module into the given stream.
By default, this representation will be the name of the module (taken from name()), followed by a recursive pretty print of all of the Module’s submodules.
Override this method to change the pretty print. The input stream should be returned from the method, to allow easy chaining.
-
Tensor &register_parameter(std::string name, Tensor tensor, bool requires_grad = true)
Registers a parameter with this Module.
A parameter should be any gradient-recording tensor used in the implementation of your Module. Registering it makes it available to methods such as parameters(), clone() or to().
Note that registering an undefined Tensor (e.g. module.register_parameter("param", Tensor())) is allowed, and is equivalent to module.register_parameter("param", None) in the Python API.
MyModule::MyModule() {
  weight_ = register_parameter("weight", torch::randn({A, B}));
}
-
Tensor &register_buffer(std::string name, Tensor tensor)
Registers a buffer with this Module.
A buffer is intended to be state in your module that does not record gradients, such as running statistics. Registering it makes it available to methods such as buffers(), clone() or to().
MyModule::MyModule() {
  mean_ = register_buffer("mean", torch::empty({num_features_}));
}
Registers a submodule with this Module.
Registering a module makes it available to methods such as modules(), clone() or to().
MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
Registers a submodule with this Module.
This method deals with ModuleHolders.
Registering a module makes it available to methods such as modules(), clone() or to().
MyModule::MyModule() {
  submodule_ = register_module("linear", torch::nn::Linear(3, 4));
}
Replaces a registered submodule with this Module.
This takes care of the registration. If you store submodules as members, you should also assign the returned module to the member, i.e. module->submodule_ = module->replace_module("linear", torch::nn::Linear(3, 4)); This only works when a module with that name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
Replaces a registered submodule with this Module.
This method deals with ModuleHolders.
This takes care of the registration. If you store submodules as members, you should also assign the returned module to the member, i.e. module->submodule_ = module->replace_module("linear", linear_holder); This only works when a module with that name is already registered.
This is useful for replacing a module after initialization, e.g. for finetuning.
Key features:
register_module(): Register submodules for parameter tracking
register_parameter(): Register learnable parameters
register_buffer(): Register non-learnable state (e.g., running mean)
parameters() / named_parameters(): Iterate over all parameters
to(): Move module to a device or convert dtype
train() / eval(): Toggle training/evaluation mode
save() / load(): Serialize and deserialize module state