---
myst:
  html_meta:
    description: PyTorch C++ neural network modules, the torch::nn API for defining and training models.
    keywords: PyTorch, C++, nn, Module, neural network, torch::nn
---

# Neural Network Modules (torch::nn)

The `torch::nn` namespace provides neural network building blocks that mirror Python's `torch.nn` module. It uses a PIMPL (Pointer to Implementation) pattern: user-facing classes such as `Conv2d` are thin holders around a `std::shared_ptr` to an implementation class such as `Conv2dImpl`. This is why module methods are called with `->` (for example, `conv1->forward(x)`).

**When to use torch::nn:**

- Building neural network models in C++
- Creating custom layers and modules
- Porting Python models to C++ for production inference
- Training models entirely in C++

**Basic usage:**

```cpp
#include <torch/torch.h>

// Define a simple model
struct Net : torch::nn::Module {
  // Empty holders; the modules are created and registered in the constructor
  torch::nn::Conv2d conv1{nullptr};
  torch::nn::Linear fc1{nullptr};

  Net() {
    conv1 = register_module("conv1", torch::nn::Conv2d(
        torch::nn::Conv2dOptions(1, 32, 3).stride(1).padding(1)));
    fc1 = register_module("fc1", torch::nn::Linear(32 * 28 * 28, 10));
  }

  torch::Tensor forward(torch::Tensor x) {
    x = torch::relu(conv1->forward(x));  // [N, 32, 28, 28]
    x = x.view({-1, 32 * 28 * 28});      // flatten for the linear layer
    return fc1->forward(x);              // [N, 10]
  }
};

int main() {
  // Create the model and run a forward pass on one 1x28x28 image
  auto model = std::make_shared<Net>();
  auto input = torch::randn({1, 1, 28, 28});
  auto output = model->forward(input);  // shape: [1, 10]
}
```

## Header Files

- `torch/nn.h` - Main neural network header (includes all modules)
- `torch/nn/module.h` - Base `Module` class
- `torch/nn/modules.h` - All module implementations
- `torch/nn/options.h` - Options structs for modules
- `torch/nn/functional.h` - Functional API

## Module Base Class

All neural network modules inherit from `torch::nn::Module`, which provides parameter management, serialization, device/dtype conversion, and hooks.
```{doxygenclass} torch::nn::Module
```

**Key features:**

- `register_module()`: Register submodules for parameter tracking
- `register_parameter()`: Register learnable parameters
- `register_buffer()`: Register non-learnable state (e.g., running mean)
- `parameters()` / `named_parameters()`: Iterate over all parameters
- `to()`: Move the module to a device or convert its dtype
- `train()` / `eval()`: Toggle training/evaluation mode
- `save()` / `load()`: Serialize and deserialize module state

## Module Categories

```{toctree}
:maxdepth: 1

containers
convolution
pooling
linear
activation
normalization
dropout
embedding
recurrent
transformer
loss
functional
utilities
```