torch.nn.attention.varlen

Created On: Oct 14, 2025 | Last Updated On: Oct 15, 2025

Variable-length attention implementation using Flash Attention.

This module provides a high-level Python interface for variable-length attention that calls into the optimized Flash Attention kernels.

torch.nn.attention.varlen.varlen_attn(query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=False, return_aux=None, scale=None)[source]

Compute variable-length attention using Flash Attention. This function is similar to scaled_dot_product_attention but optimized for variable-length sequences using cumulative sequence position tensors.

Parameters:
  • query (Tensor) – Query tensor; shape (T_q, H, D)

  • key (Tensor) – Key tensor; shape (T_k, H, D)

  • value (Tensor) – Value tensor; shape (T_k, H, D)

  • cu_seq_q (Tensor) – Cumulative sequence positions for queries; shape (N+1,)

  • cu_seq_k (Tensor) – Cumulative sequence positions for keys/values; shape (N+1,)

  • max_q (int) – Maximum query sequence length in the batch.

  • max_k (int) – Maximum key/value sequence length in the batch.

  • is_causal (bool, optional) – If set to True, applies causal masking (default: False).

  • return_aux (Optional[AuxRequest]) – If not None and return_aux.lse is True, also returns the logsumexp tensor.

  • scale (float, optional) – Scaling factor applied to attention scores. If None, defaults to 1/sqrt(D).

Returns:

Output tensor from attention computation; shape (T_q, H, D).

If return_aux is not None and return_aux.lse is True:

lse (Tensor): Log-sum-exp of attention scores; shape (T_q, H).

Return type:

output (Tensor)

Shape legend:
  • N: Batch size

  • T_q: Total number of query tokens in the batch (sum of all query sequence lengths)

  • T_k: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)

  • H: Number of attention heads

  • D: Head dimension
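
Conceptually, the packed layout lets a single call replace a loop of per-sequence scaled_dot_product_attention calls over slices defined by the cumulative sequence tensors. The sketch below is only a reference illustration of that correspondence, not the kernel implementation; the helper name varlen_reference is hypothetical.

import torch
import torch.nn.functional as F

def varlen_reference(query, key, value, cu_seq_q, cu_seq_k, is_causal=False, scale=None):
    # Hypothetical reference: run SDPA on each packed sequence separately.
    # The real kernel fuses this into one Flash Attention launch.
    outputs = []
    for i in range(cu_seq_q.numel() - 1):
        q = query[cu_seq_q[i] : cu_seq_q[i + 1]]  # (len_q_i, H, D)
        k = key[cu_seq_k[i] : cu_seq_k[i + 1]]    # (len_k_i, H, D)
        v = value[cu_seq_k[i] : cu_seq_k[i + 1]]
        # SDPA treats leading dims as batch, so put heads first: (H, L, D).
        o = F.scaled_dot_product_attention(
            q.transpose(0, 1),
            k.transpose(0, 1),
            v.transpose(0, 1),
            is_causal=is_causal,
            scale=scale,
        )
        outputs.append(o.transpose(0, 1))  # back to (len_q_i, H, D)
    return torch.cat(outputs, dim=0)       # (T_q, H, D)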

Example:

>>> import torch
>>> from torch.nn.attention.varlen import varlen_attn
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> seq_lengths = []
>>> for _ in range(batch_size):
...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
...     seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>>
>>> # Create packed query, key, value tensors
>>> query = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> key = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> value = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>>
>>> # Build cumulative sequence tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>>
>>> # Call varlen_attn
>>> output = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
class torch.nn.attention.varlen.AuxRequest(lse=False)[source]

Request which auxiliary outputs to compute from varlen_attn.

Each field is a boolean indicating whether that auxiliary output should be computed.
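
For example, to also request the log-sum-exp tensor, pass AuxRequest(lse=True) to varlen_attn. The sketch below reuses the tensors built in the example above and assumes, per the description of return_aux, that the auxiliary tensor is returned alongside the attention output:

>>> from torch.nn.attention.varlen import varlen_attn, AuxRequest
>>> out, aux_lse = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len,
...     is_causal=False, return_aux=AuxRequest(lse=True),
... )
>>> aux_lse.shape  # expected (total_tokens, num_heads), i.e. (T_q, H)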