Class MultiheadAttentionImpl#

Inheritance Relationships#

Base Type#

public torch::nn::Cloneable<MultiheadAttentionImpl>

Class Documentation#

class MultiheadAttentionImpl : public torch::nn::Cloneable<MultiheadAttentionImpl>#

Applies multi-head attention to the given query, key, and value tensors.

See https://pytorch.org/docs/main/nn.html#torch.nn.MultiheadAttention to learn about the exact behavior of this module.

See the documentation for the torch::nn::MultiheadAttentionOptions class to learn which constructor arguments this module supports.

Example:

MultiheadAttention model(MultiheadAttentionOptions(20, 10).bias(false));

Public Functions

inline MultiheadAttentionImpl(int64_t embed_dim, int64_t num_heads)#
explicit MultiheadAttentionImpl(const MultiheadAttentionOptions &options_)#
std::tuple<Tensor, Tensor> forward(const Tensor &query, const Tensor &key, const Tensor &value, const Tensor &key_padding_mask = {}, bool need_weights = true, const Tensor &attn_mask = {}, bool average_attn_weights = true)#
virtual void reset() override#

reset() must perform initialization of all members with reference semantics, most importantly parameters, buffers and submodules.

void _reset_parameters()#

Public Members

MultiheadAttentionOptions options#

The options with which this Module was constructed.

bool _qkv_same_embed_dim = {}#
Tensor in_proj_weight#
Tensor in_proj_bias#
Tensor bias_k#
Tensor bias_v#
Linear out_proj = nullptr#
Tensor q_proj_weight#
Tensor k_proj_weight#
Tensor v_proj_weight#
int64_t head_dim = {}#
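
The relationship between the options and the two derived members _qkv_same_embed_dim and head_dim can be sketched in plain C++. This is an illustration of how these values are understood to be computed, not the library's source: AttentionDims, DerivedState, and derive_state are hypothetical names, and the assumption is that head_dim equals embed_dim / num_heads (with embed_dim required to be divisible by num_heads) and that _qkv_same_embed_dim is true when kdim and vdim both equal embed_dim.

```cpp
#include <cassert>
#include <cstdint>
#include <stdexcept>

// Hypothetical stand-in for the option fields used here; not a libtorch type.
struct AttentionDims {
  int64_t embed_dim;
  int64_t num_heads;
  int64_t kdim;
  int64_t vdim;
};

// Hypothetical mirror of the two derived public members.
struct DerivedState {
  bool qkv_same_embed_dim;
  int64_t head_dim;
};

// Derives the per-head dimension and the "same embedding size" flag.
// Throws if the embedding dimension cannot be split evenly across heads.
DerivedState derive_state(const AttentionDims& d) {
  if (d.embed_dim <= 0 || d.num_heads <= 0) {
    throw std::invalid_argument("embed_dim and num_heads must be positive");
  }
  if (d.embed_dim % d.num_heads != 0) {
    throw std::invalid_argument("embed_dim must be divisible by num_heads");
  }
  DerivedState s{};
  s.qkv_same_embed_dim = (d.kdim == d.embed_dim) && (d.vdim == d.embed_dim);
  s.head_dim = d.embed_dim / d.num_heads;
  assert(s.head_dim * d.num_heads == d.embed_dim);  // postcondition
  return s;
}
```

Under these assumptions, the flag would also explain why both a packed in_proj_weight and separate q_proj_weight, k_proj_weight, and v_proj_weight members are exposed: the packed form fits when all three inputs share one embedding size, the separate form when they do not.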

Protected Functions

inline virtual bool _forward_has_default_args() override#

This function and the two that follow it (_forward_num_required_args and _forward_populate_default_args) allow a module whose forward method has default arguments to be used in a Sequential module.

Never override these functions by hand. Use the FORWARD_HAS_DEFAULT_ARGS macro instead.

inline virtual unsigned int _forward_num_required_args() override#
inline std::vector<torch::nn::AnyValue> _forward_populate_default_args(std::vector<torch::nn::AnyValue> &&arguments) override#

Friends

friend struct torch::nn::AnyModuleHolder