:github_url: https://github.com/pytorch/pytorch .. _program_listing_file_torch_csrc_api_include_torch_nn_modules_transformer.h: Program Listing for File transformer.h ====================================== |exhale_lsh| :ref:`Return to documentation for file ` (``torch/csrc/api/include/torch/nn/modules/transformer.h``) .. |exhale_lsh| unicode:: U+021B0 .. UPWARDS ARROW WITH TIP LEFTWARDS .. code-block:: cpp #pragma once #include #include #include #include #include #include #include namespace torch::nn { // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Transformer ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ class TORCH_API TransformerImpl : public Cloneable { public: explicit TransformerImpl(TransformerOptions options_); Tensor forward( const Tensor& src, const Tensor& tgt, const Tensor& src_mask = {}, const Tensor& tgt_mask = {}, const Tensor& memory_mask = {}, const Tensor& src_key_padding_mask = {}, const Tensor& tgt_key_padding_mask = {}, const Tensor& memory_key_padding_mask = {}); void reset() override; void reset_parameters(); static Tensor generate_square_subsequent_mask(int64_t sz); protected: FORWARD_HAS_DEFAULT_ARGS( {2, AnyValue(Tensor())}, {3, AnyValue(Tensor())}, {4, AnyValue(Tensor())}, {5, AnyValue(Tensor())}, {6, AnyValue(Tensor())}, {7, AnyValue(Tensor())}) public: TransformerOptions options; AnyModule encoder; AnyModule decoder; }; TORCH_MODULE(Transformer); } // namespace torch::nn