:github_url: https://github.com/pytorch/pytorch

.. _program_listing_file_torch_csrc_api_include_torch_nn_modules_transformercoder.h:

Program Listing for File transformercoder.h
===========================================

|exhale_lsh| :ref:`Return to documentation for file <file_torch_csrc_api_include_torch_nn_modules_transformercoder.h>` (``torch/csrc/api/include/torch/nn/modules/transformercoder.h``)

.. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS

.. code-block:: cpp

   #pragma once

   #include <torch/nn/cloneable.h>
   #include <torch/nn/module.h>
   #include <torch/nn/modules/common.h>
   #include <torch/nn/modules/container/any.h>
   #include <torch/nn/modules/container/modulelist.h>
   #include <torch/nn/options/transformercoder.h>
   #include <torch/nn/pimpl.h>

   #include <torch/types.h>

   #include <utility>

   namespace torch::nn {

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerEncoder
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `TransformerEncoder` module holder that
   /// `TORCH_MODULE(TransformerEncoder)` declares below.
   ///
   /// The class is configured through `TransformerEncoderOptions` and owns a
   /// `ModuleList` of layers plus a type-erased `norm` module. All non-inline
   /// member functions are only declared here; their definitions live out of
   /// line.
   class TORCH_API TransformerEncoderImpl
       : public Cloneable<TransformerEncoderImpl> {
    public:
     /// Convenience constructor: packs `encoder_layer` (moved) and `num_layers`
     /// into a `TransformerEncoderOptions` and delegates to the options
     /// constructor.
     TransformerEncoderImpl(
         TransformerEncoderLayer encoder_layer,
         int64_t num_layers)
         : TransformerEncoderImpl(
               TransformerEncoderOptions(std::move(encoder_layer), num_layers)) {}

     /// Constructs the module from a complete options object.
     explicit TransformerEncoderImpl(TransformerEncoderOptions options_);

     /// Runs the encoder on `src` and returns the resulting tensor.
     ///
     /// \param src input tensor.
     /// \param src_mask optional; defaults to a default-constructed `Tensor`.
     /// \param src_key_padding_mask optional; defaults to a default-constructed
     ///        `Tensor`.
     // NOTE(review): expected tensor shapes and mask semantics are not visible
     // in this header -- confirm against the out-of-line definition.
     Tensor forward(
         const Tensor& src,
         const Tensor& src_mask = {},
         const Tensor& src_key_padding_mask = {});

     /// Overrides the virtual `reset()` inherited through `Cloneable`.
     // NOTE(review): presumably (re)builds `layers` and `norm` from `options`
     // -- confirm against the out-of-line definition.
     void reset() override;

     // NOTE(review): presumably re-initializes the parameters of the contained
     // sub-modules -- confirm against the out-of-line definition.
     void reset_parameters();

    protected:
     // Mirrors the defaulted arguments of forward(): argument indices 1 and 2
     // (`src_mask`, `src_key_padding_mask`) default to `Tensor()`. Must stay in
     // sync with the forward() declaration above.
     FORWARD_HAS_DEFAULT_ARGS({1, AnyValue(Tensor())}, {2, AnyValue(Tensor())})

    public:
     /// Configuration of this module.
     TransformerEncoderOptions options;

     /// Layer container; starts out as a null holder.
     ModuleList layers = nullptr;

     /// Type-erased normalization module.
     AnyModule norm;
   };

   TORCH_MODULE(TransformerEncoder);

   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ TransformerDecoder
   // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

   /// Implementation class behind the `TransformerDecoder` module holder that
   /// `TORCH_MODULE(TransformerDecoder)` declares below.
   ///
   /// The class is configured through `TransformerDecoderOptions` and owns a
   /// `ModuleList` of layers plus a type-erased `norm` module. All non-inline
   /// member functions are only declared here; their definitions live out of
   /// line.
   class TORCH_API TransformerDecoderImpl
       : public Cloneable<TransformerDecoderImpl> {
    public:
     /// Convenience constructor: packs `decoder_layer` (moved) and `num_layers`
     /// into a `TransformerDecoderOptions` and delegates to the options
     /// constructor.
     TransformerDecoderImpl(
         TransformerDecoderLayer decoder_layer,
         int64_t num_layers)
         : TransformerDecoderImpl(
               TransformerDecoderOptions(std::move(decoder_layer), num_layers)) {}

     /// Constructs the module from a complete options object.
     explicit TransformerDecoderImpl(TransformerDecoderOptions options_);

     /// Overrides the virtual `reset()` inherited through `Cloneable`.
     // NOTE(review): presumably (re)builds `layers` and `norm` from `options`
     // -- confirm against the out-of-line definition.
     void reset() override;

     // NOTE(review): presumably re-initializes the parameters of the contained
     // sub-modules -- confirm against the out-of-line definition.
     void reset_parameters();

     /// Runs the decoder on `tgt` using `memory` and returns the resulting
     /// tensor.
     ///
     /// \param tgt first required input tensor.
     /// \param memory second required input tensor.
     /// \param tgt_mask optional; defaults to a default-constructed `Tensor`.
     /// \param memory_mask optional; defaults to a default-constructed `Tensor`.
     /// \param tgt_key_padding_mask optional; defaults to a default-constructed
     ///        `Tensor`.
     /// \param memory_key_padding_mask optional; defaults to a
     ///        default-constructed `Tensor`.
     // NOTE(review): expected tensor shapes and mask semantics are not visible
     // in this header -- confirm against the out-of-line definition.
     Tensor forward(
         const Tensor& tgt,
         const Tensor& memory,
         const Tensor& tgt_mask = {},
         const Tensor& memory_mask = {},
         const Tensor& tgt_key_padding_mask = {},
         const Tensor& memory_key_padding_mask = {});

     /// Configuration of this module.
     TransformerDecoderOptions options;

     /// Layer container; starts out as a null holder.
     ModuleList layers{nullptr};

     /// Type-erased normalization module.
     AnyModule norm;

    protected:
     // Mirrors the defaulted arguments of forward(): argument indices 2 to 5
     // (the four mask parameters) default to `Tensor()`. Must stay in sync with
     // the forward() declaration above.
     FORWARD_HAS_DEFAULT_ARGS(
         {2, AnyValue(Tensor())},
         {3, AnyValue(Tensor())},
         {4, AnyValue(Tensor())},
         {5, AnyValue(Tensor())})
   };

   TORCH_MODULE(TransformerDecoder);

   } // namespace torch::nn