:github_url: https://github.com/pytorch/pytorch


.. _program_listing_file_torch_csrc_api_include_torch_nn_init.h:

Program Listing for File init.h
===============================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_init.h>` (``torch/csrc/api/include/torch/nn/init.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/csrc/Export.h>
   #include <torch/enum.h>
   #include <torch/types.h>

   namespace torch {
   namespace nn::init {

   using NonlinearityType = std::variant<
       enumtype::kLinear,
       enumtype::kConv1D,
       enumtype::kConv2D,
       enumtype::kConv3D,
       enumtype::kConvTranspose1D,
       enumtype::kConvTranspose2D,
       enumtype::kConvTranspose3D,
       enumtype::kSigmoid,
       enumtype::kTanh,
       enumtype::kReLU,
       enumtype::kLeakyReLU>;

   using FanModeType = std::variant<enumtype::kFanIn, enumtype::kFanOut>;

   } // namespace nn::init

   namespace nn::init {

   TORCH_API double calculate_gain(
       NonlinearityType nonlinearity,
       double param = 0.01);

   TORCH_API Tensor constant_(Tensor tensor, Scalar value);

   TORCH_API Tensor dirac_(Tensor tensor);

   TORCH_API Tensor eye_(Tensor matrix);

   TORCH_API Tensor normal_(Tensor tensor, double mean = 0, double std = 1);

   TORCH_API Tensor ones_(Tensor tensor);

   TORCH_API Tensor orthogonal_(Tensor tensor, double gain = 1.0);

   TORCH_API Tensor sparse_(Tensor tensor, double sparsity, double std = 0.01);

   TORCH_API Tensor uniform_(Tensor tensor, double low = 0, double high = 1);

   TORCH_API Tensor kaiming_normal_(
       Tensor tensor,
       double a = 0,
       FanModeType mode = torch::kFanIn,
       NonlinearityType nonlinearity = torch::kLeakyReLU);

   TORCH_API Tensor kaiming_uniform_(
       Tensor tensor,
       double a = 0,
       FanModeType mode = torch::kFanIn,
       NonlinearityType nonlinearity = torch::kLeakyReLU);

   TORCH_API Tensor xavier_normal_(Tensor tensor, double gain = 1.0);

   TORCH_API Tensor xavier_uniform_(Tensor tensor, double gain = 1.0);

   TORCH_API Tensor zeros_(Tensor tensor);

   TORCH_API std::tuple<int64_t, int64_t> _calculate_fan_in_and_fan_out(
       const Tensor& tensor);

   } // namespace nn::init

   } // namespace torch