MultiheadAttention
class torch.nn.MultiheadAttention(embed_dim, num_heads, dropout=0.0, bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, batch_first=False, device=None, dtype=None)

Allows the model to jointly attend to information from different representation subspaces. See Attention Is All You Need.
Multi-Head Attention is defined as:

MultiHead(Q, K, V) = Concat(head_1, …, head_h) W^O

where head_i = Attention(Q W_i^Q, K W_i^K, V W_i^V).
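To make this concrete, the following is a minimal sketch (sizes chosen here for illustration, not from the original documentation). With the default kdim=None and vdim=None, the module stores a packed input projection that produces Q, K, and V, each of size embed_dim, which are then split into num_heads heads, and an output projection corresponding to W^O:

>>> import torch.nn as nn
>>> mha = nn.MultiheadAttention(embed_dim=6, num_heads=2)  # each head has dimension 6 // 2 = 3
>>> mha.in_proj_weight.shape   # packed Q, K, V projections: (3 * embed_dim, embed_dim)
torch.Size([18, 6])
>>> mha.out_proj.weight.shape  # output projection W^O: (embed_dim, embed_dim)
torch.Size([6, 6])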
- Parameters
  - embed_dim – Total dimension of the model.
  - num_heads – Number of parallel attention heads. Note that embed_dim will be split across num_heads (i.e. each head will have dimension embed_dim // num_heads).
  - dropout – Dropout probability on attn_output_weights. Default: 0.0 (no dropout).
  - bias – If specified, adds bias to input / output projection layers. Default: True.
  - add_bias_kv – If specified, adds bias to the key and value sequences at dim=0. Default: False.
  - add_zero_attn – If specified, adds a new batch of zeros to the key and value sequences at dim=1. Default: False.
  - kdim – Total number of features for keys. Default: None (uses kdim=embed_dim).
  - vdim – Total number of features for values. Default: None (uses vdim=embed_dim).
  - batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False (seq, batch, feature).
Examples:
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
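A more complete sketch with concrete sizes (chosen here for illustration, not from the original documentation), using batch_first=True so inputs and outputs are (batch, seq, feature):

>>> import torch
>>> import torch.nn as nn
>>> embed_dim, num_heads = 16, 4          # illustrative sizes; embed_dim must be divisible by num_heads
>>> multihead_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
>>> query = torch.rand(2, 4, embed_dim)   # (batch, target seq, embed_dim)
>>> key = torch.rand(2, 5, embed_dim)     # (batch, source seq, embed_dim)
>>> value = torch.rand(2, 5, embed_dim)   # (batch, source seq, embed_dim)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
>>> attn_output.shape
torch.Size([2, 4, 16])
>>> attn_output_weights.shape
torch.Size([2, 4, 5])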
forward(query, key, value, key_padding_mask=None, need_weights=True, attn_mask=None)

- Parameters
  - query – Query embeddings of shape (L, N, E_q) when batch_first=False or (N, L, E_q) when batch_first=True, where L is the target sequence length, N is the batch size, and E_q is the query embedding dimension embed_dim. Queries are compared against key-value pairs to produce the output. See “Attention Is All You Need” for more details.
  - key – Key embeddings of shape (S, N, E_k) when batch_first=False or (N, S, E_k) when batch_first=True, where S is the source sequence length, N is the batch size, and E_k is the key embedding dimension kdim. See “Attention Is All You Need” for more details.
  - value – Value embeddings of shape (S, N, E_v) when batch_first=False or (N, S, E_v) when batch_first=True, where S is the source sequence length, N is the batch size, and E_v is the value embedding dimension vdim. See “Attention Is All You Need” for more details.
  - key_padding_mask – If specified, a mask of shape (N, S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). Binary and byte masks are supported. For a binary mask, a True value indicates that the corresponding key value will be ignored for the purpose of attention. For a byte mask, a non-zero value indicates that the corresponding key value will be ignored.
  - need_weights – If specified, returns attn_output_weights in addition to attn_output. Default: True.
  - attn_mask – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L, S) or (N * num_heads, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcast across the batch while a 3D mask allows a different mask for each entry in the batch. Binary, byte, and float masks are supported. For a binary mask, a True value indicates that the corresponding position is not allowed to attend. For a byte mask, a non-zero value indicates that the corresponding position is not allowed to attend. For a float mask, the mask values will be added to the attention weight. A usage sketch of both mask arguments follows the Outputs list below.
- Outputs:
  - attn_output – Attention outputs of shape (L, N, E) when batch_first=False or (N, L, E) when batch_first=True, where L is the target sequence length, N is the batch size, and E is the embedding dimension embed_dim.
  - attn_output_weights – Attention output weights of shape (N, L, S), where N is the batch size, L is the target sequence length, and S is the source sequence length. Only returned when need_weights=True.
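The following sketch (sizes and mask contents are assumptions for illustration) shows the mask arguments of forward() together with the output shapes described above. A boolean key_padding_mask of shape (N, S) hides the last key position of every sequence, and a boolean 2D attn_mask of shape (L, S) blocks attention to later source positions:

>>> import torch
>>> import torch.nn as nn
>>> L, S, N, E = 3, 5, 2, 8                 # target length, source length, batch size, embed_dim (assumed)
>>> mha = nn.MultiheadAttention(embed_dim=E, num_heads=2)
>>> query = torch.rand(L, N, E)             # batch_first=False: (L, N, E_q)
>>> key = torch.rand(S, N, E)               # (S, N, E_k)
>>> value = torch.rand(S, N, E)             # (S, N, E_v)
>>> key_padding_mask = torch.zeros(N, S, dtype=torch.bool)
>>> key_padding_mask[:, -1] = True          # True: ignore the last key position in every sequence
>>> attn_mask = torch.triu(torch.ones(L, S, dtype=torch.bool), diagonal=1)  # True: position may not attend
>>> attn_output, attn_output_weights = mha(query, key, value,
...                                        key_padding_mask=key_padding_mask,
...                                        attn_mask=attn_mask)
>>> attn_output.shape, attn_output_weights.shape
(torch.Size([3, 2, 8]), torch.Size([2, 3, 5]))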