:github_url: https://github.com/pytorch/pytorch


.. _program_listing_file_torch_csrc_api_include_torch_nn_modules_instancenorm.h:

Program Listing for File instancenorm.h
=======================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_modules_instancenorm.h>` (``torch/csrc/api/include/torch/nn/modules/instancenorm.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   // NOTE(review): the three include targets were stripped from this listing
   // during extraction; the paths below are reconstructed -- confirm against
   // the actual header.
   #include <torch/nn/functional/instancenorm.h>
   #include <torch/nn/modules/batchnorm.h>
   #include <torch/nn/options/instancenorm.h>

   namespace torch::nn {

   /// CRTP base shared by the dimension-specific instance norm modules
   /// declared below (`InstanceNorm1dImpl`, `InstanceNorm2dImpl`,
   /// `InstanceNorm3dImpl`).
   ///
   /// `D` is the dimension suffix of the module: it is printed by
   /// `pretty_print` and used by `forward` to detect unbatched input. The
   /// second parameter is the concrete module type that derives from this
   /// class.
   ///
   /// NOTE(review): the template parameter list and the `NormImplBase`
   /// template arguments were stripped from this listing; they are
   /// reconstructed from the `InstanceNormImpl<1, InstanceNorm1dImpl>` usages
   /// below -- confirm against the actual header.
   template <size_t D, typename Derived>
   // NOLINTNEXTLINE(bugprone-crtp-constructor-accessibility)
   class InstanceNormImpl
       : public torch::nn::NormImplBase<D, Derived, InstanceNormOptions> {
    private:
     /// Forwards `input` to `functional::detail::instance_norm` together with
     /// this module's `running_mean`, `running_var`, `weight` and `bias`, and
     /// the `momentum` / `eps` options. The boolean argument is true when the
     /// module is in training mode or when `track_running_stats` is disabled.
     inline Tensor apply_instance_norm(const Tensor& input) {
       return torch::nn::functional::detail::instance_norm(
           input,
           this->running_mean,
           this->running_var,
           this->weight,
           this->bias,
           this->is_training() || !this->options.track_running_stats(),
           this->options.momentum(),
           this->options.eps());
     }

     /// Handles input without a batch dimension: inserts a leading dimension
     /// of size one, applies instance norm, then removes that dimension again.
     inline Tensor handle_no_batch_input(const Tensor& input) {
       return this->apply_instance_norm(input.unsqueeze(0)).squeeze(0);
     }

    public:
     using torch::nn::NormImplBase<D, Derived, InstanceNormOptions>::NormImplBase;

     /// Validates `input` with `_check_input_dim` and returns the instance
     /// norm of `input`. An input with `D + 1` dimensions is treated as
     /// unbatched; any other accepted input is passed through as batched.
     Tensor forward(const Tensor& input) {
       this->_check_input_dim(input);

       // For InstanceNorm1D, 2D is unbatched and 3D is batched
       // For InstanceNorm2D, 3D is unbatched and 4D is batched
       // For InstanceNorm3D, 4D is unbatched and 5D is batched
       // check if input does not have a batch-dim
       if (input.dim() == D + 1) {
         return this->handle_no_batch_input(input);
       }

       return this->apply_instance_norm(input);
     }

     /// Writes a one-line description of the module to `stream`, in the form
     /// `torch::nn::InstanceNorm<D>d(<num_features>, eps=..., momentum=...,
     /// affine=..., track_running_stats=...)`. Booleans are printed as
     /// `true` / `false` (`std::boolalpha`).
     void pretty_print(std::ostream& stream) const override {
       stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
              << this->options.num_features() << ", "
              << "eps=" << this->options.eps() << ", "
              << "momentum=" << this->options.momentum() << ", "
              << "affine=" << this->options.affine() << ", "
              << "track_running_stats=" << this->options.track_running_stats()
              << ')';
     }
   };

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Instance norm module with `D == 1`. Inherits the constructors of
   /// `InstanceNormImpl` and overrides `_check_input_dim` (declared only;
   /// the definition is not part of this header).
   class TORCH_API InstanceNorm1dImpl
       : public InstanceNormImpl<1, InstanceNorm1dImpl> {
    protected:
     void _check_input_dim(const Tensor& input) override;

    public:
     using InstanceNormImpl<1, InstanceNorm1dImpl>::InstanceNormImpl;
   };

   TORCH_MODULE(InstanceNorm1d);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm2d
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Instance norm module with `D == 2`. Inherits the constructors of
   /// `InstanceNormImpl` and overrides `_check_input_dim` (declared only;
   /// the definition is not part of this header).
   class TORCH_API InstanceNorm2dImpl
       : public InstanceNormImpl<2, InstanceNorm2dImpl> {
    protected:
     void _check_input_dim(const Tensor& input) override;

    public:
     using InstanceNormImpl<2, InstanceNorm2dImpl>::InstanceNormImpl;
   };

   TORCH_MODULE(InstanceNorm2d);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm3d
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Instance norm module with `D == 3`. Inherits the constructors of
   /// `InstanceNormImpl` and overrides `_check_input_dim` (declared only;
   /// the definition is not part of this header).
   class TORCH_API InstanceNorm3dImpl
       : public InstanceNormImpl<3, InstanceNorm3dImpl> {
    protected:
     void _check_input_dim(const Tensor& input) override;

    public:
     using InstanceNormImpl<3, InstanceNorm3dImpl>::InstanceNormImpl;
   };

   TORCH_MODULE(InstanceNorm3d);

   } // namespace torch::nn