torchtext.nn¶
MultiheadAttentionContainer¶
- class torchtext.nn.MultiheadAttentionContainer(nhead, in_proj_container, attention_layer, out_proj, batch_first=False)[source]¶
- __init__(nhead, in_proj_container, attention_layer, out_proj, batch_first=False) → None[source]¶
- A multi-head attention container.
- Parameters:
- nhead – the number of heads in the multi-head attention model 
- in_proj_container – A container of multi-head in-projection linear layers (a.k.a nn.Linear). 
- attention_layer – The custom attention layer. The input sent from the MHA container to the attention layer has shape (…, L, N * H, E / H) for query and (…, S, N * H, E / H) for key/value, while the output of the attention layer is expected to have shape (…, L, N * H, E / H). The attention_layer needs to support broadcasting if users want the overall MultiheadAttentionContainer to support broadcasting. 
- out_proj – The multi-head out-projection layer (a.k.a nn.Linear). 
- batch_first – If True, then the input and output tensors are provided as (…, N, L, E). Default: False (a batch_first variant is sketched after the example below)
 
 - Examples::
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = torch.rand((21, bsz, embed_dim))
>>> key = value = torch.rand((16, bsz, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
torch.Size([21, 64, 10])
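- When batch_first=True, only the layout of the container's inputs and outputs changes. The following is a minimal sketch, not part of the original example: the shapes and the printed output are assumptions based on the shape description above, and the inner ScaledDotProduct is left in its default layout on the assumption that the container handles the reordering itself.
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz = 10, 5, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads,
                                      in_proj_container,
                                      ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim),
                                      batch_first=True)
>>> # inputs are now (N, L, E) / (N, S, E) instead of (L, N, E) / (S, N, E)
>>> query = torch.rand((bsz, 21, embed_dim))
>>> key = value = torch.rand((bsz, 16, embed_dim))
>>> attn_output, attn_weights = MHA(query, key, value)
>>> print(attn_output.shape)
torch.Size([64, 21, 10])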
 
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) → Tuple[Tensor, Tensor][source]¶
- Parameters:
- query (Tensor) – The query of the attention function. See “Attention Is All You Need” for more details. 
- key (Tensor) – The keys of the attention function. See “Attention Is All You Need” for more details. 
- value (Tensor) – The values of the attention function. See “Attention Is All You Need” for more details. 
- attn_mask (BoolTensor, optional) – 3D mask that prevents attention to certain positions. 
- bias_k (Tensor, optional) – one more key and value sequence to be added to keys at sequence dim (dim=-3). Those are used for incremental decoding. Users should also provide bias_v.
- bias_v (Tensor, optional) – one more key and value sequence to be added to values at sequence dim (dim=-3). Those are used for incremental decoding. Users should also provide bias_k.
 
- Shape:
- Inputs:
- query: \((..., L, N, E)\) 
- key: \((..., S, N, E)\) 
- value: \((..., S, N, E)\) 
- attn_mask, bias_k and bias_v: same as the shapes of the corresponding arguments in the attention layer. 
 
- Outputs:
- attn_output: \((..., L, N, E)\) 
- attn_output_weights: \((N * H, L, S)\) 
 
- Note: It's optional to have the query/key/value inputs with more than three dimensions (for broadcast purposes). The MultiheadAttentionContainer module will operate on the last three dimensions, where L is the target length, S is the sequence length, H is the number of attention heads, N is the batch size, and E is the embedding dimension. An illustrative masked call is sketched below.
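- Since attn_mask is expected to match the attention layer's own convention (see the shape notes above), the hedged sketch below builds a boolean mask of shape (N * H, L, S) for a ScaledDotProduct layer, with True marking positions that may not be attended to. The sizes and the printed output are assumptions, not taken from the original docs.
>>> import torch
>>> from torchtext.nn import MultiheadAttentionContainer, InProjContainer, ScaledDotProduct
>>> embed_dim, num_heads, bsz, seq_len = 10, 5, 8, 7
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> MHA = MultiheadAttentionContainer(num_heads, in_proj_container, ScaledDotProduct(),
                                      torch.nn.Linear(embed_dim, embed_dim))
>>> query = key = value = torch.rand((seq_len, bsz, embed_dim))  # self-attention, so L == S
>>> # causal mask: True marks positions that may NOT be attended to
>>> attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).to(torch.bool)
>>> attn_mask = attn_mask.expand(bsz * num_heads, seq_len, seq_len)
>>> attn_output, attn_weights = MHA(query, key, value, attn_mask=attn_mask)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([7, 8, 10]) torch.Size([40, 7, 7])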
 
InProjContainer¶
- class torchtext.nn.InProjContainer(query_proj, key_proj, value_proj)[source]¶
- __init__(query_proj, key_proj, value_proj) → None[source]¶
- An in-proj container to project query/key/value in MultiheadAttention. This module is applied before reshaping the projected query/key/value into multiple heads. See the linear layers (bottom) of Multi-head Attention in Fig 2 of the Attention Is All You Need paper. Also check the usage example in torchtext.nn.MultiheadAttentionContainer.
- Parameters:
- query_proj – a proj layer for query. A typical projection layer is torch.nn.Linear. 
- key_proj – a proj layer for key. A typical projection layer is torch.nn.Linear. 
- value_proj – a proj layer for value. A typical projection layer is torch.nn.Linear. 
 
 
- forward(query: Tensor, key: Tensor, value: Tensor) → Tuple[Tensor, Tensor, Tensor][source]¶
- Projects the input sequences using the in-proj layers. query/key/value are simply passed to the forward function of query_proj/key_proj/value_proj, respectively.
- Parameters:
- query (Tensor) – The query to be projected. 
- key (Tensor) – The keys to be projected. 
- value (Tensor) – The values to be projected. 
 
 - Examples::
>>> import torch
>>> from torchtext.nn import InProjContainer
>>> embed_dim, bsz = 10, 64
>>> in_proj_container = InProjContainer(torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim),
                                        torch.nn.Linear(embed_dim, embed_dim))
>>> q = torch.rand((5, bsz, embed_dim))
>>> k = v = torch.rand((6, bsz, embed_dim))
>>> q, k, v = in_proj_container(q, k, v)
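- Because each input is simply passed through its own projection layer, the container's outputs match the individual projections applied directly. The following is a short, assumption-based check, not part of the original docs.
>>> import torch
>>> from torchtext.nn import InProjContainer
>>> embed_dim, bsz = 10, 64
>>> query_proj = torch.nn.Linear(embed_dim, embed_dim)
>>> key_proj = torch.nn.Linear(embed_dim, embed_dim)
>>> value_proj = torch.nn.Linear(embed_dim, embed_dim)
>>> in_proj_container = InProjContainer(query_proj, key_proj, value_proj)
>>> q = torch.rand((5, bsz, embed_dim))
>>> k = v = torch.rand((6, bsz, embed_dim))
>>> proj_q, proj_k, proj_v = in_proj_container(q, k, v)
>>> # each output should equal its projection layer applied directly to the input
>>> print(torch.allclose(proj_q, query_proj(q)), torch.allclose(proj_k, key_proj(k)))
True True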
 
 
ScaledDotProduct¶
- class torchtext.nn.ScaledDotProduct(dropout=0.0, batch_first=False)[source]¶
- __init__(dropout=0.0, batch_first=False) → None[source]¶
- Processes a projected query and key-value pair to apply scaled dot product attention.
- Parameters:
- dropout (float) – probability of dropping an attention weight. 
- batch_first – If True, then the input and output tensors are provided as (batch, seq, feature). Default: False
 
 - Examples::
>>> import torch, torchtext
>>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1)
>>> q = torch.randn(21, 256, 3)
>>> k = v = torch.randn(21, 256, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([21, 256, 3]) torch.Size([256, 21, 21])
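- For comparison, a hedged sketch of the same call with batch_first=True, where the (N * H) dimension leads as in (batch, seq, feature); the sizes mirror the example above and the printed output is an assumption, not from the original docs.
>>> import torch, torchtext
>>> SDP = torchtext.nn.ScaledDotProduct(dropout=0.1, batch_first=True)
>>> # with batch_first=True the (N * H) dimension comes first: (N * H, L, E / H)
>>> q = torch.randn(256, 21, 3)
>>> k = v = torch.randn(256, 21, 3)
>>> attn_output, attn_weights = SDP(q, k, v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([256, 21, 3]) torch.Size([256, 21, 21])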
 
- forward(query: Tensor, key: Tensor, value: Tensor, attn_mask: Optional[Tensor] = None, bias_k: Optional[Tensor] = None, bias_v: Optional[Tensor] = None) → Tuple[Tensor, Tensor][source]¶
- Uses a scaled dot product with the projected key-value pair to update the projected query.
- Parameters:
- query (Tensor) – Projected query 
- key (Tensor) – Projected key 
- value (Tensor) – Projected value 
- attn_mask (BoolTensor, optional) – 3D mask that prevents attention to certain positions. 
- bias_k (Tensor, optional) – one more key and value sequence to be added to keys at sequence dim (dim=-3). Those are used for incremental decoding. Users should also provide bias_v.
- bias_v (Tensor, optional) – one more key and value sequence to be added to values at sequence dim (dim=-3). Those are used for incremental decoding. Users should also provide bias_k.
 
 - Shape:
- query: \((..., L, N * H, E / H)\) 
- key: \((..., S, N * H, E / H)\) 
- value: \((..., S, N * H, E / H)\) 
- attn_mask: \((N * H, L, S)\), positions with True are not allowed to attend while False values will be unchanged. 
- bias_k and bias_v: \((1, N * H, E / H)\) 
- Output: \((..., L, N * H, E / H)\), \((N * H, L, S)\) 
- Note: It's optional to have the query/key/value inputs with more than three dimensions (for broadcast purposes). The ScaledDotProduct module will operate on the last three dimensions, where L is the target length, S is the source length, H is the number of attention heads, N is the batch size, and E is the embedding dimension.
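- To exercise the shapes above, here is a minimal, assumption-based sketch (the tensor sizes and printed output are illustrative, not from the original docs): the boolean attn_mask blocks attention where it is True, and bias_k/bias_v append one extra key/value step at the sequence dimension, so the last dimension of the returned weights is assumed to become S + 1.
>>> import torch
>>> from torchtext.nn import ScaledDotProduct
>>> SDP = ScaledDotProduct(dropout=0.0)
>>> L, S, batch_heads, head_dim = 5, 3, 8, 6   # batch_heads = N * H, head_dim = E / H
>>> q = torch.randn(L, batch_heads, head_dim)
>>> k = v = torch.randn(S, batch_heads, head_dim)
>>> # True marks positions that may NOT be attended to; here the last source position is masked
>>> attn_mask = torch.zeros(batch_heads, L, S, dtype=torch.bool)
>>> attn_mask[:, :, -1] = True
>>> # one extra key/value step appended at the sequence dimension (dim=-3)
>>> bias_k = torch.randn(1, batch_heads, head_dim)
>>> bias_v = torch.randn(1, batch_heads, head_dim)
>>> attn_output, attn_weights = SDP(q, k, v, attn_mask=attn_mask, bias_k=bias_k, bias_v=bias_v)
>>> print(attn_output.shape, attn_weights.shape)
torch.Size([5, 8, 6]) torch.Size([8, 5, 4])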