:github_url: https://github.com/pytorch/pytorch


.. _program_listing_file_torch_csrc_api_include_torch_nn_modules_activation.h:

Program Listing for File activation.h
=====================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_modules_activation.h>` (``torch/csrc/api/include/torch/nn/modules/activation.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   // NOTE(review): the include targets, the `Cloneable<...>` arguments and the
   // `std::tuple<...>` element types were lost when this page was extracted;
   // they are restored here from the upstream header -- confirm against the
   // PyTorch revision this page documents.
   #include <torch/nn/cloneable.h>
   #include <torch/nn/functional/activation.h>
   #include <torch/nn/modules/common.h>
   #include <torch/nn/modules/linear.h>
   #include <torch/nn/options/activation.h>

   #include <torch/csrc/Export.h>

   // Every class in this header only *declares* its members; no member other
   // than the inline delegating constructors is defined here.
   // Each `TORCH_MODULE(Name)` line invokes a macro defined outside this file
   // (NOTE(review): presumably it declares the `Name` holder type wrapping
   // `NameImpl` -- confirm against the macro definition).

   namespace torch::nn {

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `ELU` module; configured by `ELUOptions`.
   class TORCH_API ELUImpl : public torch::nn::Cloneable<ELUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `ELUOptions{}`).
     explicit ELUImpl(const ELUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     ELUOptions options;
   };

   TORCH_MODULE(ELU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `SELU` module; configured by
   /// `SELUOptions`.
   class TORCH_API SELUImpl : public torch::nn::Cloneable<SELUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `SELUOptions{}`).
     explicit SELUImpl(const SELUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     SELUOptions options;
   };

   TORCH_MODULE(SELU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Hardshrink` module; configured by
   /// `HardshrinkOptions`.
   class TORCH_API HardshrinkImpl : public torch::nn::Cloneable<HardshrinkImpl> {
    public:
     /// Constructs the module from `options_` (defaults to
     /// `HardshrinkOptions{}`).
     explicit HardshrinkImpl(const HardshrinkOptions& options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     HardshrinkOptions options;
   };

   TORCH_MODULE(Hardshrink);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Hardtanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Hardtanh` module; configured by
   /// `HardtanhOptions`.
   class TORCH_API HardtanhImpl : public torch::nn::Cloneable<HardtanhImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `HardtanhOptions{}`).
     explicit HardtanhImpl(const HardtanhOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     HardtanhOptions options;
   };

   TORCH_MODULE(Hardtanh);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LeakyReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `LeakyReLU` module; configured by
   /// `LeakyReLUOptions`.
   class TORCH_API LeakyReLUImpl : public torch::nn::Cloneable<LeakyReLUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to
     /// `LeakyReLUOptions{}`).
     explicit LeakyReLUImpl(const LeakyReLUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     LeakyReLUOptions options;
   };

   TORCH_MODULE(LeakyReLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `LogSigmoid` module. Declares no
   /// constructor and carries no options member.
   class TORCH_API LogSigmoidImpl : public torch::nn::Cloneable<LogSigmoidImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(LogSigmoid);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softmax` module; configured by
   /// `SoftmaxOptions`. There is no default constructor: a `dim` or an options
   /// object must be supplied.
   class TORCH_API SoftmaxImpl : public torch::nn::Cloneable<SoftmaxImpl> {
    public:
     /// Convenience constructor: delegates to the options-based constructor
     /// with `SoftmaxOptions(dim)`.
     explicit SoftmaxImpl(int64_t dim) : SoftmaxImpl(SoftmaxOptions(dim)) {}

     /// Constructs the module from `options_`.
     explicit SoftmaxImpl(const SoftmaxOptions& options_);

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     SoftmaxOptions options;
   };

   TORCH_MODULE(Softmax);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmin ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softmin` module; configured by
   /// `SoftminOptions`. There is no default constructor: a `dim` or an options
   /// object must be supplied.
   class TORCH_API SoftminImpl : public torch::nn::Cloneable<SoftminImpl> {
    public:
     /// Convenience constructor: delegates to the options-based constructor
     /// with `SoftminOptions(dim)`.
     explicit SoftminImpl(int64_t dim) : SoftminImpl(SoftminOptions(dim)) {}

     /// Constructs the module from `options_`.
     explicit SoftminImpl(const SoftminOptions& options_);

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     SoftminOptions options;
   };

   TORCH_MODULE(Softmin);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ LogSoftmax ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `LogSoftmax` module; configured by
   /// `LogSoftmaxOptions`. There is no default constructor: a `dim` or an
   /// options object must be supplied.
   class TORCH_API LogSoftmaxImpl : public torch::nn::Cloneable<LogSoftmaxImpl> {
    public:
     /// Convenience constructor: delegates to the options-based constructor
     /// with `LogSoftmaxOptions(dim)`.
     explicit LogSoftmaxImpl(int64_t dim)
         : LogSoftmaxImpl(LogSoftmaxOptions(dim)) {}

     /// Constructs the module from `options_`.
     explicit LogSoftmaxImpl(const LogSoftmaxOptions& options_);

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     LogSoftmaxOptions options;
   };

   TORCH_MODULE(LogSoftmax);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softmax2d ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softmax2d` module. Declares no
   /// constructor and carries no options member.
   class TORCH_API Softmax2dImpl : public torch::nn::Cloneable<Softmax2dImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Softmax2d);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ PReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `PReLU` module; configured by
   /// `PReLUOptions`. Unlike the other single-input activations in this header
   /// it also declares a `Tensor` member (`weight`).
   class TORCH_API PReLUImpl : public torch::nn::Cloneable<PReLUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `PReLUOptions{}`).
     explicit PReLUImpl(const PReLUOptions& options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     PReLUOptions options;

     /// Tensor member of the module.
     /// NOTE(review): presumably the learnable slope, set up in `reset()` --
     /// confirm in the implementation file.
     Tensor weight;
   };

   TORCH_MODULE(PReLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `ReLU` module; configured by
   /// `ReLUOptions`.
   class TORCH_API ReLUImpl : public torch::nn::Cloneable<ReLUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `ReLUOptions{}`).
     explicit ReLUImpl(const ReLUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     ReLUOptions options;
   };

   TORCH_MODULE(ReLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ReLU6 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `ReLU6` module; configured by
   /// `ReLU6Options`.
   class TORCH_API ReLU6Impl : public torch::nn::Cloneable<ReLU6Impl> {
    public:
     /// Constructs the module from `options_` (defaults to `ReLU6Options{}`).
     explicit ReLU6Impl(const ReLU6Options& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     ReLU6Options options;
   };

   TORCH_MODULE(ReLU6);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ RReLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `RReLU` module; configured by
   /// `RReLUOptions`.
   class TORCH_API RReLUImpl : public torch::nn::Cloneable<RReLUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `RReLUOptions{}`).
     explicit RReLUImpl(const RReLUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     RReLUOptions options;
   };

   TORCH_MODULE(RReLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `CELU` module; configured by
   /// `CELUOptions`.
   class TORCH_API CELUImpl : public torch::nn::Cloneable<CELUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `CELUOptions{}`).
     explicit CELUImpl(const CELUOptions& options_ = {});

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     CELUOptions options;
   };

   TORCH_MODULE(CELU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `GLU` module; configured by `GLUOptions`.
   class TORCH_API GLUImpl : public torch::nn::Cloneable<GLUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `GLUOptions{}`).
     explicit GLUImpl(const GLUOptions& options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     GLUOptions options;
   };

   TORCH_MODULE(GLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GELU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `GELU` module; configured by
   /// `GELUOptions`.
   class TORCH_API GELUImpl : public torch::nn::Cloneable<GELUImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `GELUOptions{}`).
     /// Note that, unlike its siblings in this header, this constructor takes
     /// the options object by value rather than by const reference.
     explicit GELUImpl(GELUOptions options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     GELUOptions options;
   };

   TORCH_MODULE(GELU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ SiLU ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `SiLU` module. Declares no constructor
   /// and carries no options member.
   class TORCH_API SiLUImpl : public torch::nn::Cloneable<SiLUImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(SiLU);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Mish ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Mish` module. Declares no constructor
   /// and carries no options member.
   class TORCH_API MishImpl : public torch::nn::Cloneable<MishImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Mish);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Sigmoid ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Sigmoid` module. Declares no constructor
   /// and carries no options member.
   class TORCH_API SigmoidImpl : public torch::nn::Cloneable<SigmoidImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Sigmoid);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softplus ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softplus` module; configured by
   /// `SoftplusOptions`.
   class TORCH_API SoftplusImpl : public torch::nn::Cloneable<SoftplusImpl> {
    public:
     /// Constructs the module from `options_` (defaults to `SoftplusOptions{}`).
     explicit SoftplusImpl(const SoftplusOptions& options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     SoftplusOptions options;
   };

   TORCH_MODULE(Softplus);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softshrink` module; configured by
   /// `SoftshrinkOptions`.
   class TORCH_API SoftshrinkImpl : public torch::nn::Cloneable<SoftshrinkImpl> {
    public:
     /// Constructs the module from `options_` (defaults to
     /// `SoftshrinkOptions{}`).
     explicit SoftshrinkImpl(const SoftshrinkOptions& options_ = {});

     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     SoftshrinkOptions options;
   };

   TORCH_MODULE(Softshrink);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Softsign ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Softsign` module. Declares no
   /// constructor and carries no options member.
   class TORCH_API SoftsignImpl : public torch::nn::Cloneable<SoftsignImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Softsign);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanh ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Tanh` module. Declares no constructor
   /// and carries no options member.
   class TORCH_API TanhImpl : public torch::nn::Cloneable<TanhImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Tanh);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tanhshrink ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Tanhshrink` module. Declares no
   /// constructor and carries no options member.
   class TORCH_API TanhshrinkImpl : public torch::nn::Cloneable<TanhshrinkImpl> {
    public:
     /// Runs the module on `input` (const reference) and returns a `Tensor`.
     Tensor forward(const Tensor& input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;
   };

   TORCH_MODULE(Tanhshrink);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Threshold ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `Threshold` module; configured by
   /// `ThresholdOptions`. There is no default constructor: either the two
   /// doubles or an options object must be supplied.
   class TORCH_API ThresholdImpl : public torch::nn::Cloneable<ThresholdImpl> {
    public:
     /// Convenience constructor: delegates to the options-based constructor
     /// with `ThresholdOptions(threshold, value)`.
     ThresholdImpl(double threshold, double value)
         : ThresholdImpl(ThresholdOptions(threshold, value)) {}

     /// Constructs the module from `options_`.
     explicit ThresholdImpl(const ThresholdOptions& options_);

     /// Runs the module on `input` (taken by value) and returns a `Tensor`.
     Tensor forward(Tensor input);

     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Writes a textual description of this module to `stream`.
     void pretty_print(std::ostream& stream) const override;

     /// Options this module was declared to carry.
     ThresholdOptions options;
   };

   TORCH_MODULE(Threshold);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MultiheadAttention ~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `MultiheadAttention` module; configured
   /// by `MultiheadAttentionOptions`. Unlike the other classes in this header
   /// it declares no `pretty_print` override and exposes several tensor members
   /// plus a `Linear` submodule.
   class TORCH_API MultiheadAttentionImpl
       : public torch::nn::Cloneable<MultiheadAttentionImpl> {
    public:
     /// Convenience constructor: delegates to the options-based constructor
     /// with `MultiheadAttentionOptions(embed_dim, num_heads)`.
     MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)
         : MultiheadAttentionImpl(
               MultiheadAttentionOptions(embed_dim, num_heads)) {}

     /// Constructs the module from `options_`.
     explicit MultiheadAttentionImpl(const MultiheadAttentionOptions& options_);

     /// Runs the module on `query`, `key` and `value`. The trailing four
     /// parameters are optional and default to an empty tensor / `true`.
     /// NOTE(review): the tuple element types were restored from the upstream
     /// header; presumably (attention output, attention weights) -- confirm.
     std::tuple<Tensor, Tensor> forward(
         const Tensor& query,
         const Tensor& key,
         const Tensor& value,
         const Tensor& key_padding_mask = {},
         bool need_weights = true,
         const Tensor& attn_mask = {},
         bool average_attn_weights = true);

    protected:
     // Lists, by zero-based parameter index, the same default values that the
     // `forward` declaration above gives for parameters 3 through 6. Keep the
     // two in sync when either changes.
     FORWARD_HAS_DEFAULT_ARGS(
         {3, AnyValue(Tensor())},
         {4, AnyValue(true)},
         {5, AnyValue(Tensor())},
         {6, AnyValue(true)})

    public:
     /// Overrides the base-class `reset()` hook.
     void reset() override;

     /// Declared helper; defined outside this header.
     void _reset_parameters();

     /// Options this module was declared to carry.
     MultiheadAttentionOptions options;

     /// Value-initialised (i.e. `false`) until assigned elsewhere.
     bool _qkv_same_embed_dim{};

     Tensor in_proj_weight;
     Tensor in_proj_bias;
     Tensor bias_k;
     Tensor bias_v;

     /// Starts out as a null module holder.
     /// NOTE(review): presumably assigned in `reset()` -- confirm.
     Linear out_proj = nullptr;

     Tensor q_proj_weight;
     Tensor k_proj_weight;
     Tensor v_proj_weight;

     /// Value-initialised (i.e. `0`) until assigned elsewhere.
     int64_t head_dim{};
   };

   TORCH_MODULE(MultiheadAttention);

   } // namespace torch::nn