torch.nn.attention.varlen
Created On: Oct 14, 2025 | Last Updated On: Oct 15, 2025
Variable-length attention implementation using Flash Attention.
This module provides a high-level Python interface for variable-length attention that calls into the optimized Flash Attention kernels.
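For orientation, here is a minimal sketch of the packed ("varlen") layout this module operates on: sequences of different lengths are concatenated along a single token dimension, and their boundaries are recorded as cumulative offsets. The offset construction mirrors the reference example below; the tensor names and the per-sequence slicing at the end are purely illustrative.
>>> import torch
>>>
>>> # Three sequences of different lengths, packed (concatenated) along one
>>> # token dimension instead of padded to a common length.
>>> seq_lengths = torch.tensor([5, 2, 7])
>>> num_heads, head_dim = 4, 8
>>> packed = torch.randn(int(seq_lengths.sum()), num_heads, head_dim)
>>>
>>> # Cumulative offsets: cu_seq[i] is where sequence i starts and
>>> # cu_seq[i + 1] is where it ends, so cu_seq has batch_size + 1 entries.
>>> cu_seq = torch.zeros(len(seq_lengths) + 1, dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> cu_seq.tolist()  # [0, 5, 7, 14]
>>>
>>> # Tokens of an individual sequence are recovered by slicing with the offsets.
>>> packed[cu_seq[1] : cu_seq[2]].shape  # torch.Size([2, 4, 8])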
- torch.nn.attention.varlen.varlen_attn(query, key, value, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=False, return_aux=None)
Compute variable-length attention using Flash Attention. This function is similar to scaled_dot_product_attention but optimized for variable-length sequences using cumulative sequence position tensors (a per-sequence comparison against scaled_dot_product_attention is sketched after this entry).
Args:
- query (Tensor): Query tensor; shape (S_q, H, D)
- key (Tensor): Key tensor; shape (S_kv, H, D)
- value (Tensor): Value tensor; shape (S_kv, H, D)
- cu_seq_q (Tensor): Cumulative sequence positions for queries; shape (N + 1,)
- cu_seq_k (Tensor): Cumulative sequence positions for keys/values; shape (N + 1,)
- max_q (int): Maximum query sequence length in the batch.
- max_k (int): Maximum key/value sequence length in the batch.
- is_causal (bool, optional): If set to True, applies causal masking (default: False).
- return_aux (Optional[AuxRequest]): If not None and return_aux.lse is True, also returns the logsumexp tensor.
Shape legend:
- N: Batch size
- S_q: Total number of query tokens in the batch (sum of all query sequence lengths)
- S_kv: Total number of key/value tokens in the batch (sum of all key/value sequence lengths)
- H: Number of attention heads
- D: Head dimension
Returns:
- Tensor: Output tensor from attention computation
- If return_aux is not None and return_aux.lse is True, returns a tuple of Tensors (output, lse), where lse is the logsumexp tensor (see the return_aux sketch after the example below).
Example:
>>> import torch
>>> from torch.nn.attention.varlen import varlen_attn
>>>
>>> batch_size, max_seq_len, embed_dim, num_heads = 2, 512, 1024, 16
>>> head_dim = embed_dim // num_heads
>>> seq_lengths = []
>>> for _ in range(batch_size):
...     length = torch.randint(1, max_seq_len // 64 + 1, (1,)).item() * 64
...     seq_lengths.append(min(length, max_seq_len))
>>> seq_lengths = torch.tensor(seq_lengths, device="cuda")
>>> total_tokens = seq_lengths.sum().item()
>>>
>>> # Create packed query, key, value tensors
>>> query = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> key = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>> value = torch.randn(
...     total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda"
... )
>>>
>>> # Build cumulative sequence tensor
>>> cu_seq = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32)
>>> cu_seq[1:] = seq_lengths.cumsum(0)
>>> max_len = seq_lengths.max().item()
>>>
>>> # Call varlen_attn
>>> output = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len, is_causal=False
... )
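Continuing from the tensors built in the example above, the sketch below exercises the is_causal flag together with return_aux. It assumes AuxRequest lives in torch.nn.attention.varlen and can be constructed as AuxRequest(lse=True); the entry above only documents the return_aux.lse field, so treat the import path and constructor as assumptions.
>>> from torch.nn.attention.varlen import AuxRequest  # assumed import path
>>>
>>> # Causal masking plus the logsumexp auxiliary output; with return_aux.lse
>>> # set, varlen_attn returns an (output, lse) tuple instead of a single Tensor.
>>> aux = AuxRequest(lse=True)  # assumed constructor
>>> causal_out, lse = varlen_attn(
...     query, key, value, cu_seq, cu_seq, max_len, max_len,
...     is_causal=True, return_aux=aux,
... )
>>> causal_out.shape  # expected to match the packed query layout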
- Return type: Tensor, or tuple[Tensor, Tensor] when return_aux is not None and return_aux.lse is True
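Since the entry describes varlen_attn as similar to scaled_dot_product_attention, one way to sanity-check a non-causal packed call is to compare it, sequence by sequence, against scaled_dot_product_attention on the unpacked slices. This sketch reuses output, query, key, value, cu_seq, and batch_size from the example above and assumes the two paths agree within typical float16 tolerances; the tolerance values are illustrative, not documented guarantees.
>>> import torch.nn.functional as F
>>>
>>> for i in range(batch_size):
...     start, end = cu_seq[i].item(), cu_seq[i + 1].item()
...     # scaled_dot_product_attention expects (batch, heads, seq, head_dim),
...     # so unpack one sequence and move heads in front of the token dimension.
...     q_i = query[start:end].transpose(0, 1).unsqueeze(0)
...     k_i = key[start:end].transpose(0, 1).unsqueeze(0)
...     v_i = value[start:end].transpose(0, 1).unsqueeze(0)
...     ref = F.scaled_dot_product_attention(q_i, k_i, v_i)
...     out_i = output[start:end].transpose(0, 1).unsqueeze(0)
...     torch.testing.assert_close(out_i, ref, atol=1e-2, rtol=1e-2)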