apply_low_precision_attention
torchao.prototype.attention.apply_low_precision_attention(model: Module, backend: AttentionBackend | None = None, hadamard: HadamardMode = HadamardMode.NONE) -> Module
Apply low-precision attention to a model.
Must be called before torch.compile. KV caching should be disabled before calling (e.g., config.use_cache = False for HuggingFace models).

This replaces F.scaled_dot_product_attention with an FP8 SDPA for eager execution and sets a global pre-grad pass so that torch.compile will automatically fuse RoPE where detected.

Parameters:
model – The model to apply low-precision attention to.
backend – Backend to use. If None, auto-detected.
hadamard – Hadamard transform mode.
HadamardMode.QKV applies the Hadamard transform to Q, K, and V before FP8 quantization, spreading outliers across the head dimension for better dynamic range utilization. HadamardMode.V_ONLY applies the transform to V only, improving V quantization quality without the cost of transforming Q and K. Requires the head dimension D to be a power of 2 and <= 256.
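
To make the "spreading outliers" idea concrete, here is a minimal pure-Python sketch of an orthonormal Walsh-Hadamard transform. It is an illustration of the underlying math only, not torchao's kernel: the helper names (is_supported_head_dim, hadamard_transform, dot) are hypothetical, and the 256 limit is simply taken from the parameter description above. It shows that a single large value is spread evenly across the head dimension, and that rotating both Q and K leaves their dot product unchanged.

```python
import math

def is_supported_head_dim(d, max_dim=256):
    """Hypothetical check: d is a power of 2 and no larger than max_dim."""
    return d > 0 and (d & (d - 1)) == 0 and d <= max_dim

def hadamard_transform(vec):
    """Orthonormal fast Walsh-Hadamard transform of a list of floats."""
    n = len(vec)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    out = list(vec)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of elements that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal
    return [x * scale for x in out]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

# One head with D = 8: q has a single large outlier, k is well behaved.
q = [16.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
k = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]

q_rot = hadamard_transform(q)
k_rot = hadamard_transform(k)

peak_before = max(abs(x) for x in q)      # the outlier dominates the range
peak_after = max(abs(x) for x in q_rot)   # energy is spread over all 8 slots

print(peak_before, peak_after)
print(dot(q, k), dot(q_rot, k_rot))       # equal up to rounding
```

Because the normalized transform is orthogonal, the attention scores computed from rotated Q and K match the unrotated ones up to rounding, while the smaller peak value leaves more of the FP8 range available for the remaining entries.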
Example:
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torchao.prototype.attention import apply_low_precision_attention

    # Simple model with attention
    class MyModel(nn.Module):
        def __init__(self, embed_dim=512, num_heads=8):
            super().__init__()
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
            self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
            self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

        def forward(self, x):
            B, S, _ = x.shape
            q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
            k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
            v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
            attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
            return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))

    model = MyModel().to(device="cuda", dtype=torch.bfloat16).eval()

    # Auto-detect best backend
    model = apply_low_precision_attention(model)

    # Or specify a backend explicitly
    # model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)

    # Optional: torch.compile for RoPE fusion
    model = torch.compile(model)
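
The ordering constraints stated in the description (disable KV caching, then apply low-precision attention, then compile) can be wrapped in a small helper. The following is only a sketch: disable_kv_cache and prepare_for_low_precision_attention are hypothetical names, not part of torchao, and the helper assumes a HuggingFace-style model that exposes config.use_cache. The torch and torchao imports are done lazily inside the function so the cache-disabling part can be tried on its own.

```python
from types import SimpleNamespace

def disable_kv_cache(model):
    """Hypothetical helper: turn off KV caching if the model exposes the flag."""
    config = getattr(model, "config", None)
    if config is not None and hasattr(config, "use_cache"):
        config.use_cache = False
    return model

def prepare_for_low_precision_attention(model, compile_model=True):
    """Hypothetical wrapper that enforces the documented call order."""
    # Imported lazily; this function needs torch and torchao installed.
    import torch
    from torchao.prototype.attention import apply_low_precision_attention

    disable_kv_cache(model)                       # 1. KV cache off first
    model = apply_low_precision_attention(model)  # 2. backend auto-detected
    if compile_model:
        model = torch.compile(model)              # 3. compile only afterwards
    return model

# Stand-in object that mimics the config attribute of a HuggingFace model.
fake_model = SimpleNamespace(config=SimpleNamespace(use_cache=True))
disable_kv_cache(fake_model)
print(fake_model.config.use_cache)
```

With a real model, calling prepare_for_low_precision_attention(model) once at load time would keep the three steps in the required order; models without a config attribute pass through disable_kv_cache unchanged.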