
Source code for torchtune.modules.attention

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
from typing import Optional

import torch
from torch import nn
from torchtune.modules.attention_utils import _MaskType, _sdpa_or_flex_attention
from torchtune.modules.kv_cache import KVCache

logger = logging.getLogger(__name__)


class MultiHeadAttention(nn.Module):
    """Multi-headed attention layer with support for grouped query
    attention (GQA) introduced in https://arxiv.org/abs/2305.13245v1.

    GQA is a version of multiheaded attention (MHA) which uses fewer
    key/value heads than query heads by grouping n query heads for each
    key and value head. Multi-Query Attention is an extreme
    version where we have a single key and value head shared by all
    query heads.

    Following is an example of MHA, GQA and MQA with num_heads = 4
    (credit for the documentation:
    `litgpt.Config <https://github.com/Lightning-AI/litgpt/blob/eda1aaaf391fd689664f95487ab03dc137e213fd/litgpt/config.py>`_).

    ::

        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ v ││ v ││ v ││ v │     │ v │    │ v │             │ v │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │         │        │                 │
        ┌───┐┌───┐┌───┐┌───┐     ┌───┐    ┌───┐             ┌───┐
        │ k ││ k ││ k ││ k │     │ k │    │ k │             │ k │
        └───┘└───┘└───┘└───┘     └───┘    └───┘             └───┘
        │    │    │    │      ┌──┴──┐  ┌──┴──┐      ┌────┬──┴─┬────┐
        ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐  ┌───┐┌───┐┌───┐┌───┐
        │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │  │ q ││ q ││ q ││ q │
        └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘  └───┘└───┘└───┘└───┘
        ◀──────────────────▶  ◀──────────────────▶  ◀──────────────────▶
                MHA                    GQA                   MQA
           n_kv_heads=4           n_kv_heads=2           n_kv_heads=1

    Args:
        embed_dim (int): embedding dimension for the model
        num_heads (int): number of query heads. For MHA this is also the
            number of heads for key and value
        num_kv_heads (int): number of key and value heads. User should ensure
            ``num_heads % num_kv_heads == 0``. For standard MHA set
            ``num_kv_heads == num_heads``, for GQA ``num_kv_heads < num_heads``,
            and for MQA set ``num_kv_heads == 1``.
        head_dim (int): dimension of each head, calculated by ``embed_dim // num_heads``.
        q_proj (nn.Module): projection layer for query.
        k_proj (nn.Module): projection layer for key.
        v_proj (nn.Module): projection layer for value.
        output_proj (nn.Module): projection layer for output.
        pos_embeddings (Optional[nn.Module]): positional embeddings layer, e.g.
            RotaryPositionalEmbeddings.
        q_norm (Optional[nn.Module]): normalization layer for query, e.g. RMSNorm. For decoding,
            this is applied before updating from kv_cache. This means it will only support
            token wide normalization and not batch or sequence wide normalization.
        k_norm (Optional[nn.Module]): normalization layer for key, must be set if q_norm is.
        kv_cache (Optional[KVCache]): KVCache object used to cache key and value
        max_seq_len (int): maximum sequence length supported by the model.
            This is needed to compute the RoPE Cache. Default: 4096.
        is_causal (bool): sets the default mask to causal when no mask is provided
        attn_dropout (float): dropout value passed onto the scaled_dot_product_attention function.
            Default value is 0.0.
    Raises:
        ValueError: If ``num_heads % num_kv_heads != 0``
        ValueError: If ``embed_dim % num_heads != 0``
        ValueError: If ``attn_dropout < 0`` or ``attn_dropout > 1``
        ValueError: If q_norm is defined without k_norm or vice versa
    """

    def __init__(
        self,
        *,
        embed_dim: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        q_proj: nn.Module,
        k_proj: nn.Module,
        v_proj: nn.Module,
        output_proj: nn.Module,
        pos_embeddings: Optional[nn.Module] = None,
        q_norm: Optional[nn.Module] = None,
        k_norm: Optional[nn.Module] = None,
        kv_cache: Optional[KVCache] = None,
        max_seq_len: int = 4096,
        is_causal: bool = True,
        attn_dropout: float = 0.0,
    ) -> None:
        super().__init__()
        if num_heads % num_kv_heads != 0:
            raise ValueError(
                f"num_heads ({num_heads}) must be divisible by "
                f"num_kv_heads ({num_kv_heads})"
            )
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        if attn_dropout < 0 or attn_dropout > 1:
            raise ValueError(
                f"attn_dropout ({attn_dropout}) must be between 0.0 and 1.0"
            )
        if bool(q_norm) ^ bool(k_norm):
            raise ValueError("q and k norm must be set together")

        # Set attributes
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.embed_dim = embed_dim
        self.attn_dropout = attn_dropout
        self.head_dim = head_dim
        self.max_seq_len = max_seq_len
        self.is_causal = is_causal

        # Set layers
        self.kv_cache = kv_cache
        self.q_proj = q_proj
        self.k_proj = k_proj
        self.v_proj = v_proj
        self.output_proj = output_proj
        self.q_norm = q_norm
        self.k_norm = k_norm
        self.pos_embeddings = pos_embeddings

        # Use flex attention if supported and we are sample packing
        self._attention_call = _sdpa_or_flex_attention()

        # this flag indicates whether to update the kv-cache during forward
        # passes. when disabled, we can have the cache setup but still
        # perform normal forward passes
        self.cache_enabled = False
    def setup_cache(
        self, batch_size: int, dtype: torch.dtype, max_seq_len: int
    ) -> None:
        """Setup key value caches for attention calculation. If called
        after kv_cache is already setup, this will be skipped.

        Args:
            batch_size (int): batch size for the caches.
            dtype (torch.dtype): dtype for the caches.
            max_seq_len (int): maximum sequence length model will be run with.
        """
        # Don't overwrite user defined kv_cache from init
        if self.kv_cache is not None:
            logger.warning(
                "Key value caches are already setup. You cannot call ``setup_caches()`` twice. Skipping."
            )
        else:
            self.kv_cache = KVCache(
                batch_size=batch_size,
                max_seq_len=max_seq_len,
                num_kv_heads=self.num_kv_heads,
                head_dim=self.head_dim,
                dtype=dtype,
            )
            self.cache_enabled = True
    def reset_cache(self):
        """Reset the key value caches."""
        if self.kv_cache is None:
            raise RuntimeError(
                "Key value caches are not setup. Call ``setup_caches()`` first."
            )
        self.kv_cache.reset()
    def forward(
        self,
        x: torch.Tensor,
        y: Optional[torch.Tensor] = None,
        *,
        mask: Optional[_MaskType] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): input tensor with shape [b x s_x x d] for the query
            y (Optional[torch.Tensor]): second input tensor with shape [b x s_y x d], is the input
                for k and v. For self attention, x=y. Optional only with kv_cache enabled.
            mask (Optional[_MaskType]): Used to mask the scores after the query-key multiplication
                and before the softmax. Either:

                A boolean tensor with shape ``[b x s x s]``, ``[b x s x self.encoder_max_cache_seq_len]``,
                or ``[b x s x self.decoder_max_cache_seq_len]`` if using KV-caching with encoder/decoder layers.
                A value of True in row ``i`` and column ``j`` means token ``i`` attends to token ``j``. A value
                of False means token ``i`` does not attend to token ``j``. If no mask is specified, a causal
                mask is used by default.

                A :class:`~torch.nn.attention.flex_attention.BlockMask` for document masking in a packed
                sequence created via `create_block_mask <https://pytorch.org/blog/flexattention/#mask-mods>`_.
                We use :func:`~torch.nn.attention.flex_attention.flex_attention` when computing attention
                with block masks. Default is None.
            input_pos (Optional[torch.Tensor]): Optional tensor which contains the position ids
                of each token. During training, this is used to indicate the positions
                of each token relative to its sample when packed, shape [b x s].
                During inference, this indicates the position of the current token.
                If none, assume the index of the token is its position id. Default is None.

        Raises:
            ValueError: If no ``y`` input and ``kv_cache`` is not enabled.

        Returns:
            torch.Tensor: output tensor with attention applied

        Notation used for tensor shapes:
            - b: batch size
            - s_x: sequence length for x
            - s_y: sequence length for y
            - n_h: num heads
            - n_kv: num kv heads
            - d: embed dim
            - h_d: head dim
        """
        # x has shape [b, s_x, d]
        # y has shape [b, s_y, d]
        b, s_x, _ = x.shape
        s_y = y.shape[1] if y is not None else 0

        # q has shape [b, s_x, num_heads * head_dim]
        q = self.q_proj(x)

        # number of queries per key/value
        q_per_kv = self.num_heads // self.num_kv_heads
        q = q.view(b, s_x, self.num_kv_heads * q_per_kv, self.head_dim)

        # Apply positional embeddings
        if self.pos_embeddings is not None:
            q = self.pos_embeddings(q, input_pos=input_pos)

        # [b, n_h, s_x, h_d]
        q = q.transpose(1, 2)

        # Normalize q
        if self.q_norm is not None:
            q = self.q_norm(q)

        if y is None:
            if self.kv_cache is None or not self.cache_enabled:
                raise ValueError(
                    "Must provide y input or use kv_cache to enable streaming decoding"
                )
            k = self.kv_cache.k_cache
            v = self.kv_cache.v_cache
        else:
            # Update k and v shape, positional embeddings, and normalization

            # k,v shape [b, s_y, num_kv_heads * head_dim]
            k = self.k_proj(y)
            v = self.v_proj(y)

            # Apply positional embeddings
            # k,v shape: [b, s_y, n_kv, h_d]
            k = k.view(b, s_y, -1, self.head_dim)
            v = v.view(b, s_y, -1, self.head_dim)
            if self.pos_embeddings is not None:
                k = self.pos_embeddings(k, input_pos=input_pos)

            # k,v shape: [b, n_kv, s_y, h_d]
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)

            # Normalize k
            if self.k_norm is not None:
                k = self.k_norm(k)

            # Update key-value cache
            if self.kv_cache is not None and self.cache_enabled:
                k, v = self.kv_cache.update(k, v)

        # If needed, expand the key and value tensors to have the same shape
        # as the query tensor by copying values across the relevant dim
        # k,v shape: [b, n_kv, s, h_d] -> [b, n_h, s, h_d]
        if self.num_heads != self.num_kv_heads:
            expand_shape = (b, self.num_kv_heads, q_per_kv, -1, self.head_dim)
            k = k.unsqueeze(2).expand(expand_shape).flatten(1, 2)
            v = v.unsqueeze(2).expand(expand_shape).flatten(1, 2)

        output = self._attention_call(
            q,
            k,
            v,
            mask=mask,
            dropout_p=self.attn_dropout if self.training else 0.0,
            is_causal=self.kv_cache is None and mask is None and self.is_causal,
        )

        # reshape the output to be the same shape as the input
        output = output.transpose(1, 2).contiguous().view(b, s_x, -1)
        return self.output_proj(output)
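
For orientation, below is a minimal usage sketch, not part of the torchtune source. It wires the projection layers with plain ``nn.Linear`` modules in a GQA configuration (8 query heads, 2 key/value heads), runs a self-attention forward pass, and then allocates KV caches for decoding. The dimensions are illustrative assumptions, and positional embeddings and normalization layers are omitted for brevity; during generation the surrounding decoder normally supplies ``mask`` and ``input_pos``.

    # Illustrative example only; not part of torchtune.modules.attention
    import torch
    from torch import nn
    from torchtune.modules.attention import MultiHeadAttention

    embed_dim, num_heads, num_kv_heads = 256, 8, 2   # GQA: 4 query heads per kv head
    head_dim = embed_dim // num_heads

    attn = MultiHeadAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_kv_heads=num_kv_heads,
        head_dim=head_dim,
        q_proj=nn.Linear(embed_dim, num_heads * head_dim, bias=False),
        k_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        v_proj=nn.Linear(embed_dim, num_kv_heads * head_dim, bias=False),
        output_proj=nn.Linear(embed_dim, embed_dim, bias=False),
        max_seq_len=4096,
    )

    x = torch.randn(2, 16, embed_dim)                # [b, s, d]
    out = attn(x, x)                                 # self-attention: y = x
    assert out.shape == (2, 16, embed_dim)

    # Optionally allocate KV caches before incremental decoding
    attn.setup_cache(batch_size=2, dtype=torch.float32, max_seq_len=4096)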
