torch.nn.attention.varlen
Variable-length attention implementation using Flash Attention.
This module provides a high-level Python interface for variable-length attention that calls into the optimized Flash Attention kernels.
- torch.nn.attention.varlen.varlen_attn(query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=False, return_aux=None, scale=None)
Compute variable-length attention using Flash Attention. This function is similar to scaled_dot_product_attention but optimized for variable-length sequences using cumulative sequence position tensors.
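For intuition, here is a minimal reference sketch of the semantics (not the actual implementation; the helper name is illustrative): varlen_attn is equivalent to running scaled_dot_product_attention independently on each packed sequence. The sketch assumes self-attention, i.e. cu_seq_q == cu_seq_k.

import torch
import torch.nn.functional as F

def varlen_reference(q, k, v, cu_seq, is_causal=False):
    # q, k, v: packed (total_tokens, num_heads, head_dim) tensors
    outs = []
    for i in range(cu_seq.numel() - 1):
        s, e = cu_seq[i], cu_seq[i + 1]
        # SDPA expects (batch, heads, seq, head_dim)
        qi = q[s:e].transpose(0, 1).unsqueeze(0)
        ki = k[s:e].transpose(0, 1).unsqueeze(0)
        vi = v[s:e].transpose(0, 1).unsqueeze(0)
        oi = F.scaled_dot_product_attention(qi, ki, vi, is_causal=is_causal)
        outs.append(oi.squeeze(0).transpose(0, 1))
    return torch.cat(outs, dim=0)  # back to packed (total_tokens, heads, head_dim)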
- Parameters:
query (Tensor) – Query tensor; shape (T_q, H, D).
key (Tensor) – Key tensor; shape (T_kv, H, D).
value (Tensor) – Value tensor; shape (T_kv, H, D).
cu_seq_q (Tensor) – Cumulative sequence positions for queries; shape (N + 1,).
cu_seq_k (Tensor) – Cumulative sequence positions for keys/values; shape (N + 1,).
max_q (int) – Maximum query sequence length in the batch.
max_k (int) – Maximum key/value sequence length in the batch.
is_causal (bool, optional) – If set to True, applies causal masking (default: False).
return_aux (Optional[AuxRequest]) – If not None and return_aux.lse is True, also returns the log-sum-exp tensor.
scale (float, optional) – Scaling factor applied to attention scores before the softmax. If None, defaults to 1 / sqrt(D).
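For intuition, a small sketch (with hypothetical lengths) of how the cumulative position tensors map per-sequence lengths to offsets in the packed layout:

import torch

seq_lengths = torch.tensor([3, 5, 2])  # hypothetical per-sequence lengths
cu_seq = torch.zeros(seq_lengths.numel() + 1, dtype=torch.int32)
cu_seq[1:] = seq_lengths.cumsum(0)     # -> tensor([0, 3, 8, 10], dtype=torch.int32)
# Sequence i occupies packed rows cu_seq[i]:cu_seq[i+1],
# and max_q / max_k are simply seq_lengths.max().item().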
- Returns:
output (Tensor): Output tensor from the attention computation; shape (T_q, H, D).
If return_aux is not None and return_aux.lse is True, also returns lse (Tensor): log-sum-exp of the attention scores; shape (H, T_q).
- Return type:
output (Tensor), or a tuple (output, lse) when the log-sum-exp is requested
- Shape legend:
N: Batch size
T_q: Total number of query tokens in the batch (sum of all query sequence lengths)
T_kv: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
H: Number of attention heads
D: Head dimension
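When the auxiliary log-sum-exp output is needed, it can be requested through return_aux. A minimal sketch, assuming AuxRequest accepts an lse keyword as the parameter docs above suggest, and reusing the packed tensors and cu_seq/max_len built in the example below:

from torch.nn.attention.varlen import AuxRequest, varlen_attn

# Request the log-sum-exp of the attention scores alongside the output
# (AuxRequest(lse=True) is assumed from the return_aux docs above).
out, lse = varlen_attn(
    query, key, value, cu_seq, cu_seq, max_len, max_len,
    is_causal=False, return_aux=AuxRequest(lse=True),
)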
Example:
>>> import torch
>>> from torch.nn.attention.varlen import varlen_attn
>>>
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> # Sample a random length (a multiple of 64) for each sequence in the batch
>>> seq_lengths = []
>>> for _ in range(batch_size):
...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
...     seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>>
>>> # Create packed query, key, value tensors
>>> query = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> key = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> value = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>>
>>> # Build the cumulative sequence position tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>>
>>> # Call varlen_attn (self-attention, so the same cu_seq is used for q and k)
>>> output = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
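Because the output keeps the packed layout, per-sequence results can be recovered by slicing with the same cumulative positions; continuing the example above:

# Split the packed output back into one tensor per sequence
outputs = [
    output[cu_seq[i] : cu_seq[i + 1]]  # (seq_lengths[i], num_heads, head_dim)
    for i in range(batch_size)
]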