torch.nn.attention.varlen#
Created On: Oct 14, 2025 | Last Updated On: Mar 10, 2026
Variable-length attention implementation using Flash Attention.
This module provides a high-level Python interface for variable-length attention that calls into the optimized Flash Attention kernels.
- torch.nn.attention.varlen.varlen_attn(query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, *, return_aux=None, scale=None, window_size=(-1, -1), seqused_k=None, block_table=None, num_splits=None)[source]#
Compute variable-length attention using Flash Attention.
This function is similar to scaled_dot_product_attention but optimized for variable-length sequences using cumulative sequence position tensors.
- Parameters:
  - query (Tensor) – Query tensor; shape (total_q, H, D).
  - key (Tensor) – Key tensor; shape (total_k, H, D), or (num_pages, page_size, H, D) when block_table is provided.
  - value (Tensor) – Value tensor; shape (total_k, H, D), or (num_pages, page_size, H, D) when block_table is provided.
  - cu_seq_q (Tensor) – Cumulative sequence positions for queries; shape (N + 1,).
  - cu_seq_k (Tensor) – Cumulative sequence positions for keys/values; shape (N + 1,).
  - max_q (int) – Maximum query sequence length in the batch.
  - max_k (int) – Maximum key/value sequence length in the batch.
  - return_aux (Optional[AuxRequest]) – If not None and return_aux.lse is True, also returns the logsumexp tensor.
  - scale (float, optional) – Scaling factor for attention scores.
  - window_size (tuple[int, int], optional) – Window size for sliding-window attention as (left, right). Use (-1, -1) for full attention (default), (-1, 0) for causal attention, or (W, 0) for causal attention with a sliding window of size W.
  - seqused_k (Tensor, optional) – Number of valid KV tokens per batch element; shape (N,). When set, only the first seqused_k[i] tokens in the key/value sequence for batch element i participate in attention. Useful for KV-cache decoding where the cache slot is larger than the actual sequence. Inference-only (not supported in backward).
  - block_table (Tensor, optional) – Block table for paged KV cache; shape (N, max_num_pages), dtype int32. Requires seqused_k. Inference-only (not supported in backward). When block_table is provided, key and value are a "pool" of pages of KV tokens, and the pages may belong to any sequence in any order. The block_table maps each sequence's logical chunks back to physical pages in this pool; seqused_k[i] tells the kernel how many tokens in sequence i are actually valid, since the last page is typically only partially filled.
  - num_splits (int, optional) – Number of splits for split-KV. Set to 1 to disable split-KV, which enables batch invariance. Split-KV parallelizes the key/value sequence dimension across multiple thread blocks and combines the partial results. The split decision depends on max_k (the longest sequence in the batch), so different batch compositions can change the reduction order and produce different floating-point results for the same sequence. When split-KV is disabled, bitwise-identical outputs are guaranteed for a given sequence regardless of what other sequences are in the batch, at the cost of lower GPU utilization when there are few queries. When None (default), the kernel chooses automatically.
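The batch-invariance caveat for num_splits comes down to floating-point addition not being associative. A minimal plain-Python sketch (independent of the kernel) of how combining partial sums in a different order changes the result:

```python
# Plain-Python illustration (not the kernel itself): floating-point
# addition is not associative, so combining partial results in a
# different order can change the final bits. Varying the number of
# splits in split-KV reorders the reduction in exactly this way.
vals = [1e16, 1.0, 1.0, 1.0, 1.0, -1e16]

# One sequential reduction: each +1.0 is lost to rounding at 1e16
# (the spacing between adjacent doubles there is 2.0).
sequential = 0.0
for v in vals:
    sequential += v                      # ends at 0.0

# Two partial sums combined afterwards, as a 2-way split would do.
part_a = vals[0]                         # 1e16
part_b = vals[1] + vals[2] + vals[3] + vals[4] + vals[5]  # 4.0 - 1e16
split = part_a + part_b                  # 4.0 -- the small terms survive
```

Same inputs, two different reduction orders, two different answers; pinning num_splits=1 removes this source of run-to-run and batch-to-batch variation.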
- Returns:
  output (Tensor): Output tensor from the attention computation; shape (total_q, H, D). If return_aux is not None and return_aux.lse is True, additionally returns lse (Tensor), the log-sum-exp of the attention scores.
- Return type:
  output (Tensor)
- Shape legend:
  - N: Batch size
  - total_q: Total number of query tokens in the batch (sum of all query sequence lengths)
  - total_k: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
  - H: Number of attention heads
  - D: Head dimension
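To make the packed layout concrete, here is a plain-Python sketch of how cumulative sequence positions index a packed token buffer (the helper names are illustrative, not part of the API):

```python
from itertools import accumulate

def build_cu_seq(seq_lengths):
    # cu_seq has N + 1 entries: [0, l0, l0 + l1, ...]; the last entry
    # equals the total token count (total_q or total_k).
    return [0] + list(accumulate(seq_lengths))

def unpack(packed, cu_seq):
    # Sequence i occupies packed[cu_seq[i]:cu_seq[i + 1]].
    return [packed[cu_seq[i]:cu_seq[i + 1]] for i in range(len(cu_seq) - 1)]

seq_lengths = [3, 2, 4]                  # three sequences, 9 tokens total
cu_seq = build_cu_seq(seq_lengths)       # [0, 3, 5, 9]
packed = list(range(sum(seq_lengths)))   # stand-in for a packed buffer
sequences = unpack(packed, cu_seq)       # [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
```

This is the same convention the example below follows when it builds cu_seq with a cumulative sum.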
Example:

>>> import torch
>>> from torch.nn.attention.varlen import varlen_attn
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> seq_lengths = []
>>> for _ in range(batch_size):
...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
...     seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>>
>>> # Create packed query, key, value tensors
>>> query = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> key = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> value = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>>
>>> # Build cumulative sequence tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>>
>>> # Call varlen_attn
>>> output = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len
... )
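The window_size semantics can likewise be sketched in plain Python. The helper below is illustrative, not part of the API: for a query at position q_idx, it lists which key positions fall inside a (left, right) window, with -1 meaning unbounded on that side.

```python
def attended_keys(q_idx, seq_len, window_size):
    # Keys visible to query position q_idx under a (left, right) window;
    # -1 disables the bound on that side.
    left, right = window_size
    lo = 0 if left == -1 else max(0, q_idx - left)
    hi = seq_len - 1 if right == -1 else min(seq_len - 1, q_idx + right)
    return list(range(lo, hi + 1))

full = attended_keys(2, 5, (-1, -1))    # [0, 1, 2, 3, 4]: full attention
causal = attended_keys(2, 5, (-1, 0))   # [0, 1, 2]: keys up to the query
windowed = attended_keys(4, 5, (2, 0))  # [2, 3, 4]: causal, window W = 2
```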
- torch.nn.attention.varlen.varlen_attn_out(out, query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, *, return_aux=None, scale=None, window_size=(-1, -1), seqused_k=None, block_table=None, num_splits=None)[source]#
Compute variable-length attention using Flash Attention with a pre-allocated output tensor.
Same as varlen_attn(), but writes the attention output into the provided out tensor instead of allocating a new one.