apply_low_precision_attention

torchao.prototype.attention.apply_low_precision_attention(model: Module, backend: AttentionBackend | None = None, hadamard: HadamardMode = HadamardMode.NONE) → Module

Apply low-precision attention to a model.

Must be called before torch.compile. KV caching should be disabled before calling (e.g., config.use_cache = False for HuggingFace models).
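The KV-cache requirement can be sketched as a small guard that runs before the call. This is a minimal illustration, not part of torchao: disable_kv_cache is a hypothetical helper, and it assumes a HuggingFace-style model that exposes config.use_cache. A stand-in object is used so the sketch runs without loading a real model.

```python
from types import SimpleNamespace


def disable_kv_cache(model):
    """Hypothetical helper: turn off KV caching on a HuggingFace-style model.

    Models without a `config.use_cache` attribute are returned unchanged.
    """
    config = getattr(model, "config", None)
    if config is not None and hasattr(config, "use_cache"):
        config.use_cache = False
    return model


# Stand-in for a HuggingFace model (assumption: real models expose
# `model.config.use_cache` in the same way).
dummy = SimpleNamespace(config=SimpleNamespace(use_cache=True))
disable_kv_cache(dummy)
print(dummy.config.use_cache)

# Intended order with a real model:
#   1. disable_kv_cache(model)
#   2. model = apply_low_precision_attention(model)
#   3. model = torch.compile(model)  # optional
```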

This replaces F.scaled_dot_product_attention with an FP8 SDPA for eager execution and sets a global pre-grad pass so that torch.compile will automatically fuse RoPE where detected.

Parameters:
  • model – The model to apply low-precision attention to.

  • backend – Backend to use. If None, auto-detected.

  • hadamard – Hadamard transform mode. HadamardMode.NONE (the default) applies no transform. HadamardMode.QKV applies the Hadamard transform to Q, K, and V before FP8 quantization, spreading outliers across the head dimension for better dynamic range utilization. HadamardMode.V_ONLY applies the transform to V only, improving V quantization quality without the cost of transforming Q and K. Both transform modes require the head dimension D to be a power of 2 and <= 256.
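To illustrate why a Hadamard transform helps before quantization, here is a minimal pure-Python sketch. It is not torchao's kernel: hadamard_transform and is_supported_head_dim are hypothetical names introduced for this example. It applies an orthonormal fast Walsh-Hadamard transform to a single vector of head dimension D containing one large outlier and compares the peak magnitude before and after.

```python
import math


def is_supported_head_dim(d: int) -> bool:
    """Hypothetical check: True when d is a power of 2 and at most 256."""
    return 0 < d <= 256 and (d & (d - 1)) == 0


def hadamard_transform(x: list[float]) -> list[float]:
    """Orthonormal fast Walsh-Hadamard transform of a power-of-2-length vector."""
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of 2")
    out = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of entries that are h apart.
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b
        h *= 2
    # Scale by 1/sqrt(n) so the transform preserves the vector norm.
    scale = 1.0 / math.sqrt(n)
    return [v * scale for v in out]


# One large outlier in an otherwise small vector of head dimension D = 64.
D = 64
assert is_supported_head_dim(D)
vec = [0.01] * D
vec[3] = 8.0

rotated = hadamard_transform(vec)
print(max(abs(v) for v in vec))      # peak magnitude before the transform
print(max(abs(v) for v in rotated))  # peak magnitude after the transform
```

The transform keeps the vector's norm but distributes the outlier's energy over all D entries, so the peak value that a quantization scale must cover becomes much smaller. The power-of-2 check mirrors the constraint stated for the hadamard parameter.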

Example:

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchao.prototype.attention import apply_low_precision_attention


# Simple model with attention
class MyModel(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))


model = MyModel().to(device="cuda", dtype=torch.bfloat16).eval()

# Auto-detect best backend
model = apply_low_precision_attention(model)

# Or specify a backend explicitly (AttentionBackend must also be imported)
# model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)

# Optional: torch.compile for RoPE fusion
model = torch.compile(model)