:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_torch_csrc_api_include_torch_nn_options_activation.h:

Program Listing for File activation.h
=====================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_options_activation.h>` (``torch/csrc/api/include/torch/nn/options/activation.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/arg.h>
   #include <torch/csrc/Export.h>
   #include <torch/enum.h>
   #include <torch/types.h>

   // Option structs for activation-related modules and their `functional`
   // counterparts.
   //
   // NOTE(review): TORCH_API and TORCH_ARG are macros defined outside this
   // listing (presumably in <torch/csrc/Export.h> and <torch/arg.h> -- confirm).
   // The inline ThresholdOptions constructor below initializes `threshold_` and
   // `value_`, so TORCH_ARG(T, name) evidently declares a member called `name_`;
   // any accessors it also generates are not visible here.
   //
   // Conventions visible in this file:
   //   * `TORCH_ARG(T, name) = v;` gives the option an in-class default of `v`.
   //   * `TORCH_ARG(T, name);` has no in-class default; where the struct also
   //     declares a constructor taking that value, the constructor is expected
   //     to supply it (out-of-line constructors are defined elsewhere).
   //   * `/* implicit */` marks a single-argument constructor that is
   //     deliberately left non-`explicit`.

   namespace torch::nn {

   /// ELU options: `alpha` defaults to 1.0, `inplace` defaults to false.
   struct TORCH_API ELUOptions {
     TORCH_ARG(double, alpha) = 1.0;

     TORCH_ARG(bool, inplace) = false;
   };

   namespace functional {
   /// Alias: `functional::ELUFuncOptions` is the same type as `ELUOptions`.
   using ELUFuncOptions = ELUOptions;
   } // namespace functional

   // ============================================================================

   /// SELU options holding a single `inplace` flag.
   ///
   /// Implicitly constructible from `bool`; the constructor argument defaults
   /// to false. The constructor is only declared here.
   struct TORCH_API SELUOptions {
     /* implicit */ SELUOptions(bool inplace = false);

     TORCH_ARG(bool, inplace);
   };

   namespace functional {
   /// Alias: `functional::SELUFuncOptions` is the same type as `SELUOptions`.
   using SELUFuncOptions = SELUOptions;
   } // namespace functional

   // ============================================================================

   /// GLU options holding a single `dim` value.
   ///
   /// Implicitly constructible from `int64_t`; the constructor argument
   /// defaults to -1. The constructor is only declared here.
   struct TORCH_API GLUOptions {
     /* implicit */ GLUOptions(int64_t dim = -1);

     TORCH_ARG(int64_t, dim);
   };

   namespace functional {
   /// Alias: `functional::GLUFuncOptions` is the same type as `GLUOptions`.
   using GLUFuncOptions = GLUOptions;
   } // namespace functional

   // ============================================================================

   /// GELU options: `approximate` is a string defaulting to "none".
   ///
   /// NOTE(review): the set of accepted `approximate` strings is not visible in
   /// this file -- confirm against the implementation.
   struct TORCH_API GELUOptions {
     TORCH_ARG(std::string, approximate) = "none";
   };

   namespace functional {
   /// Alias: `functional::GELUFuncOptions` is the same type as `GELUOptions`.
   using GELUFuncOptions = GELUOptions;
   } // namespace functional

   // ============================================================================

   /// Hardshrink options holding a single `lambda` value.
   ///
   /// Implicitly constructible from `double`; the constructor argument defaults
   /// to 0.5. The constructor is only declared here.
   struct TORCH_API HardshrinkOptions {
     /* implicit */ HardshrinkOptions(double lambda = 0.5);

     TORCH_ARG(double, lambda);
   };

   namespace functional {
   /// Alias: `functional::HardshrinkFuncOptions` is the same type as
   /// `HardshrinkOptions`.
   using HardshrinkFuncOptions = HardshrinkOptions;
   } // namespace functional

   // ============================================================================

   /// Hardtanh options: `min_val` defaults to -1.0, `max_val` to 1.0 and
   /// `inplace` to false.
   struct TORCH_API HardtanhOptions {
     TORCH_ARG(double, min_val) = -1.0;

     TORCH_ARG(double, max_val) = 1.0;

     TORCH_ARG(bool, inplace) = false;
   };

   namespace functional {
   /// Alias: `functional::HardtanhFuncOptions` is the same type as
   /// `HardtanhOptions`.
   using HardtanhFuncOptions = HardtanhOptions;
   } // namespace functional

   // ============================================================================

   /// LeakyReLU options: `negative_slope` defaults to 1e-2, `inplace` to false.
   struct TORCH_API LeakyReLUOptions {
     TORCH_ARG(double, negative_slope) = 1e-2;

     TORCH_ARG(bool, inplace) = false;
   };

   namespace functional {
   /// Alias: `functional::LeakyReLUFuncOptions` is the same type as
   /// `LeakyReLUOptions`.
   using LeakyReLUFuncOptions = LeakyReLUOptions;
   } // namespace functional

   // ============================================================================

   /// Softmax options holding a required `dim`.
   ///
   /// `dim` has no default: the only declared constructor takes it. That
   /// constructor is not marked `explicit` and is only declared here.
   struct TORCH_API SoftmaxOptions {
     SoftmaxOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);
   };

   // ============================================================================

   namespace functional {

   /// Softmax options for the `functional` namespace.
   ///
   /// Unlike `SoftmaxOptions` this is a distinct struct, not an alias: besides
   /// the required `dim` it carries an optional `dtype` that defaults to
   /// `std::nullopt`.
   struct TORCH_API SoftmaxFuncOptions {
     SoftmaxFuncOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);

     TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
   };

   } // namespace functional

   // ============================================================================

   /// Softmin options holding a required `dim`.
   ///
   /// `dim` has no default: the only declared constructor takes it. That
   /// constructor is not marked `explicit` and is only declared here.
   struct TORCH_API SoftminOptions {
     SoftminOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);
   };

   // ============================================================================

   namespace functional {

   /// Softmin options for the `functional` namespace.
   ///
   /// A distinct struct (not an alias of `SoftminOptions`): besides the
   /// required `dim` it carries an optional `dtype` that defaults to
   /// `std::nullopt`.
   struct TORCH_API SoftminFuncOptions {
     SoftminFuncOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);

     TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
   };

   } // namespace functional

   // ============================================================================

   /// LogSoftmax options holding a required `dim`.
   ///
   /// `dim` has no default: the only declared constructor takes it. That
   /// constructor is not marked `explicit` and is only declared here.
   struct TORCH_API LogSoftmaxOptions {
     LogSoftmaxOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);
   };

   // ============================================================================

   namespace functional {

   /// LogSoftmax options for the `functional` namespace.
   ///
   /// A distinct struct (not an alias of `LogSoftmaxOptions`): besides the
   /// required `dim` it carries an optional `dtype` that defaults to
   /// `std::nullopt`.
   struct TORCH_API LogSoftmaxFuncOptions {
     LogSoftmaxFuncOptions(int64_t dim);

     TORCH_ARG(int64_t, dim);

     TORCH_ARG(std::optional<torch::Dtype>, dtype) = std::nullopt;
   };

   } // namespace functional

   // ============================================================================

   /// PReLU options: `num_parameters` defaults to 1, `init` defaults to 0.25.
   ///
   /// No `functional` alias or counterpart is declared for this struct in this
   /// file.
   struct TORCH_API PReLUOptions {
     TORCH_ARG(int64_t, num_parameters) = 1;

     TORCH_ARG(double, init) = 0.25;
   };

   // ============================================================================

   /// ReLU options holding a single `inplace` flag.
   ///
   /// Implicitly constructible from `bool`; the constructor argument defaults
   /// to false. The constructor is only declared here.
   struct TORCH_API ReLUOptions {
     /* implicit */ ReLUOptions(bool inplace = false);

     TORCH_ARG(bool, inplace);
   };

   namespace functional {
   /// Alias: `functional::ReLUFuncOptions` is the same type as `ReLUOptions`.
   using ReLUFuncOptions = ReLUOptions;
   } // namespace functional

   // ============================================================================

   /// ReLU6 options holding a single `inplace` flag.
   ///
   /// Implicitly constructible from `bool`; the constructor argument defaults
   /// to false. The constructor is only declared here.
   struct TORCH_API ReLU6Options {
     /* implicit */ ReLU6Options(bool inplace = false);

     TORCH_ARG(bool, inplace);
   };

   namespace functional {
   /// Alias: `functional::ReLU6FuncOptions` is the same type as `ReLU6Options`.
   using ReLU6FuncOptions = ReLU6Options;
   } // namespace functional

   // ============================================================================

   /// RReLU options: `lower` defaults to 1/8, `upper` to 1/3 and `inplace` to
   /// false.
   struct TORCH_API RReLUOptions {
     TORCH_ARG(double, lower) = 1.0 / 8.0;

     TORCH_ARG(double, upper) = 1.0 / 3.0;

     TORCH_ARG(bool, inplace) = false;
   };

   // ============================================================================

   namespace functional {

   /// RReLU options for the `functional` namespace.
   ///
   /// A distinct struct (not an alias of `RReLUOptions`): it has the same
   /// `lower`, `upper` and `inplace` defaults plus a `training` flag that
   /// defaults to false.
   struct TORCH_API RReLUFuncOptions {
     TORCH_ARG(double, lower) = 1.0 / 8.0;

     TORCH_ARG(double, upper) = 1.0 / 3.0;

     TORCH_ARG(bool, training) = false;

     TORCH_ARG(bool, inplace) = false;
   };

   } // namespace functional

   // ============================================================================

   /// CELU options: `alpha` defaults to 1.0, `inplace` defaults to false.
   struct TORCH_API CELUOptions {
     TORCH_ARG(double, alpha) = 1.0;

     TORCH_ARG(bool, inplace) = false;
   };

   namespace functional {
   /// Alias: `functional::CELUFuncOptions` is the same type as `CELUOptions`.
   using CELUFuncOptions = CELUOptions;
   } // namespace functional

   // ============================================================================

   /// Softplus options: `beta` defaults to 1.0, `threshold` defaults to 20.0.
   struct TORCH_API SoftplusOptions {
     TORCH_ARG(double, beta) = 1.0;

     TORCH_ARG(double, threshold) = 20.0;
   };

   namespace functional {
   /// Alias: `functional::SoftplusFuncOptions` is the same type as
   /// `SoftplusOptions`.
   using SoftplusFuncOptions = SoftplusOptions;
   } // namespace functional

   // ============================================================================

   /// Softshrink options holding a single `lambda` value.
   ///
   /// Implicitly constructible from `double`; the constructor argument defaults
   /// to 0.5. The constructor is only declared here.
   struct TORCH_API SoftshrinkOptions {
     /* implicit */ SoftshrinkOptions(double lambda = 0.5);

     TORCH_ARG(double, lambda);
   };

   namespace functional {
   /// Alias: `functional::SoftshrinkFuncOptions` is the same type as
   /// `SoftshrinkOptions`.
   using SoftshrinkFuncOptions = SoftshrinkOptions;
   } // namespace functional

   // ============================================================================

   /// Threshold options: `threshold` and `value` are required, `inplace`
   /// defaults to false.
   ///
   /// This is the only struct in the file whose constructor is defined inline;
   /// it stores its two arguments directly into `threshold_` and `value_`.
   struct TORCH_API ThresholdOptions {
     ThresholdOptions(double threshold, double value)
         : threshold_(threshold), value_(value) {}

     TORCH_ARG(double, threshold);

     TORCH_ARG(double, value);

     TORCH_ARG(bool, inplace) = false;
   };

   namespace functional {
   /// Alias: `functional::ThresholdFuncOptions` is the same type as
   /// `ThresholdOptions`.
   using ThresholdFuncOptions = ThresholdOptions;
   } // namespace functional

   // ============================================================================

   namespace functional {

   /// Gumbel-softmax options, declared only in the `functional` namespace:
   /// `tau` defaults to 1.0, `hard` to false and `dim` to -1.
   ///
   /// Note that `dim` is declared as `int` here, whereas every other `dim`
   /// option in this file is `int64_t`.
   struct TORCH_API GumbelSoftmaxFuncOptions {
     TORCH_ARG(double, tau) = 1.0;

     TORCH_ARG(bool, hard) = false;

     TORCH_ARG(int, dim) = -1;
   };

   } // namespace functional

   // ============================================================================

   /// Multi-head attention options.
   ///
   /// `embed_dim` and `num_heads` are required constructor arguments. `dropout`
   /// defaults to 0.0, `bias` to true, and `add_bias_kv` / `add_zero_attn` to
   /// false.
   ///
   /// `kdim` and `vdim` have no in-class default and are not constructor
   /// parameters. NOTE(review): presumably the out-of-line constructor
   /// initializes them -- confirm in the corresponding .cpp.
   struct TORCH_API MultiheadAttentionOptions {
     MultiheadAttentionOptions(int64_t embed_dim, int64_t num_heads);

     TORCH_ARG(int64_t, embed_dim);

     TORCH_ARG(int64_t, num_heads);

     TORCH_ARG(double, dropout) = 0.0;

     TORCH_ARG(bool, bias) = true;

     TORCH_ARG(bool, add_bias_kv) = false;

     TORCH_ARG(bool, add_zero_attn) = false;

     TORCH_ARG(int64_t, kdim);

     TORCH_ARG(int64_t, vdim);
   };

   // ============================================================================

   namespace functional {

   /// Options for the multi-head attention forward call in the `functional`
   /// namespace.
   ///
   /// The ten constructor parameters mirror the first ten options, in the same
   /// order; none of those has an in-class default. The constructor is only
   /// declared here.
   ///
   /// Of the remaining options, `training`, `need_weights` and
   /// `average_attn_weights` default to true and `use_separate_proj_weight`
   /// defaults to false. The `Tensor` options after the constructor-backed ones
   /// (`key_padding_mask`, `attn_mask`, `q_proj_weight`, `k_proj_weight`,
   /// `v_proj_weight`, `static_k`, `static_v`) have no initializer in this file.
   /// NOTE(review): presumably they start as default-constructed tensors --
   /// confirm against the TORCH_ARG definition.
   struct TORCH_API MultiheadAttentionForwardFuncOptions {
     MultiheadAttentionForwardFuncOptions(
         int64_t embed_dim_to_check,
         int64_t num_heads,
         Tensor in_proj_weight,
         Tensor in_proj_bias,
         Tensor bias_k,
         Tensor bias_v,
         bool add_zero_attn,
         double dropout_p,
         Tensor out_proj_weight,
         Tensor out_proj_bias);

     TORCH_ARG(int64_t, embed_dim_to_check);

     TORCH_ARG(int64_t, num_heads);

     TORCH_ARG(Tensor, in_proj_weight);

     TORCH_ARG(Tensor, in_proj_bias);

     TORCH_ARG(Tensor, bias_k);

     TORCH_ARG(Tensor, bias_v);

     TORCH_ARG(bool, add_zero_attn);

     TORCH_ARG(double, dropout_p);

     TORCH_ARG(Tensor, out_proj_weight);

     TORCH_ARG(Tensor, out_proj_bias);

     TORCH_ARG(bool, training) = true;

     TORCH_ARG(Tensor, key_padding_mask);

     TORCH_ARG(bool, need_weights) = true;

     TORCH_ARG(Tensor, attn_mask);

     TORCH_ARG(bool, use_separate_proj_weight) = false;

     TORCH_ARG(Tensor, q_proj_weight);

     TORCH_ARG(Tensor, k_proj_weight);

     TORCH_ARG(Tensor, v_proj_weight);

     TORCH_ARG(Tensor, static_k);

     TORCH_ARG(Tensor, static_v);

     TORCH_ARG(bool, average_attn_weights) = true;
   };

   } // namespace functional

   } // namespace torch::nn