LLMMaskedCategorical
- class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = -100)[source]
LLM-optimized masked categorical distribution.
This class provides a more memory-efficient approach for LLM training by:
1. Using ignore_index=-100 for log_prob computation (no masking overhead).
2. Using traditional masking for sampling operations.
This is particularly beneficial for large vocabulary sizes where masking all logits can be memory-intensive.
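For intuition, the two strategies correspond roughly to the following plain-PyTorch operations (a sketch only; variable names and shapes are illustrative, not taken from the implementation):
>>> import torch
>>> import torch.nn.functional as F
>>> B, T, C = 2, 10, 50000
>>> logits = torch.randn(B, T, C)
>>> tokens = torch.randint(0, C, (B, T))
>>> # Strategy 1 (log_prob path): positions to ignore carry ignore_index and are
>>> # skipped by cross_entropy -- no [B, T, C] boolean mask is materialized.
>>> targets = tokens.clone()
>>> targets[0, :5] = -100
>>> per_token_nll = F.cross_entropy(
...     logits.view(-1, C), targets.view(-1), ignore_index=-100, reduction="none"
... ).view(B, T)
>>> # Strategy 2 (sampling path): disallowed tokens are pushed to -inf before
>>> # sampling, which does require touching the full [B, T, C] logits.
>>> token_mask = torch.ones(B, T, C, dtype=torch.bool)
>>> token_mask[0, :, :1000] = False  # forbid the first 1000 tokens in sequence 0
>>> masked_logits = logits.masked_fill(~token_mask, float("-inf"))
>>> samples = torch.distributions.Categorical(logits=masked_logits).sample()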
- Parameters:
logits (torch.Tensor) – Event log probabilities (unnormalized), shape [B, T, C].
  - B: batch size (optional)
  - T: sequence length
  - C: vocabulary size (number of classes)
mask (torch.Tensor) –
Boolean mask indicating valid positions/tokens.
  - If shape [*B, T]: position-level masking. True means the position is valid (all tokens allowed).
  - If shape [*B, T, C]: token-level masking. True means the token is valid at that position.
Warning
Token-level masking is considerably more memory-intensive than position-level masking. Only use it if you need to disallow specific tokens at specific positions.
ignore_index (int, optional) – Index to ignore in log_prob computation. Defaults to -100.
- Input shapes:
  - logits: [B, T, C]
  - mask: [*B, T] (position-level) or [*B, T, C] (token-level)
  - tokens (for log_prob): [B, T], containing valid token indices or ignore_index
- Use cases:
- Position-level masking
>>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
>>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
>>> mask[0, :5] = False  # mask first 5 positions of first sequence
>>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
>>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
>>> tokens[0, :5] = -100  # set masked positions to ignore_index
>>> log_probs = dist.log_prob(tokens)
>>> samples = dist.sample()  # [B, T]
- Token-level masking
>>> logits = torch.randn(2, 10, 50000)
>>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
>>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
>>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
>>> tokens = torch.randint(0, 50000, (2, 10))
>>> # Optionally, set tokens at fully-masked positions to ignore_index
>>> log_probs = dist.log_prob(tokens)
>>> samples = dist.sample()  # [B, T]
Notes
For log_prob, tokens must be of shape [B, T] and contain valid token indices (0 <= token < C), or ignore_index for masked/ignored positions.
For token-level masking, if a token is masked at a given position, log_prob will return -inf for that entry.
For position-level masking, if a position is masked (ignore_index), log_prob will return 0.0 for that entry (correct for cross-entropy loss).
Sampling always respects the mask (masked tokens/positions are never sampled).
All documented use cases are covered by tests in test_distributions.py.
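The notes above can be illustrated with a tiny vocabulary (a minimal sketch; the shapes and values are chosen only for readability):
>>> import torch
>>> from torchrl.modules import LLMMaskedCategorical
>>> # Token-level mask: a forbidden token gets log_prob -inf.
>>> logits = torch.randn(1, 4, 8)                        # [B=1, T=4, C=8]
>>> tok_mask = torch.ones(1, 4, 8, dtype=torch.bool)
>>> tok_mask[0, 0, 3] = False                            # forbid token 3 at position 0
>>> dist = LLMMaskedCategorical(logits=logits, mask=tok_mask)
>>> dist.log_prob(torch.tensor([[3, 1, 2, 0]]))[0, 0]    # -> -inf
>>> # Position-level mask: an ignored position (ignore_index) gets log_prob 0.0.
>>> pos_mask = torch.ones(1, 4, dtype=torch.bool)
>>> pos_mask[0, 0] = False                               # ignore position 0
>>> dist = LLMMaskedCategorical(logits=logits, mask=pos_mask)
>>> dist.log_prob(torch.tensor([[-100, 1, 2, 0]]))[0, 0] # -> 0.0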
- log_prob(value: Tensor) → Tensor [source]
Compute log probabilities using ignore_index approach.
This is memory-efficient as it doesn’t require masking the logits. The value tensor should use ignore_index for masked positions.
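For example, a masked negative log-likelihood can be assembled from the per-position output (a sketch assuming position-level masking, with ignore_index already written into tokens; the reduction is up to the caller):
>>> log_probs = dist.log_prob(tokens)          # [B, T]; 0.0 at ignore_index positions
>>> valid = tokens != -100
>>> nll = -(log_probs * valid).sum() / valid.sum().clamp(min=1)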
- property masked_dist: Categorical
Get the masked distribution for sampling operations.
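Since this is documented to return a regular Categorical, it can be inspected directly; a sketch, reusing the token-level use case above (where the first 1000 tokens are masked at the first 5 positions of sequence 0):
>>> masked = dist.masked_dist                  # torch.distributions.Categorical
>>> masked.probs[0, 0, :1000].sum()            # ~0.0: masked tokens carry no probability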
- property position_level_masking: bool
Whether the mask is position-level (True) or token-level (False).
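A quick check of the mode, reusing the logits from the examples above (sketch):
>>> pos = torch.ones(2, 10, dtype=torch.bool)
>>> LLMMaskedCategorical(logits=logits, mask=pos).position_level_masking       # True
>>> tok = torch.ones(2, 10, 50000, dtype=torch.bool)
>>> LLMMaskedCategorical(logits=logits, mask=tok).position_level_masking       # False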
- rsample(sample_shape: torch.Size | Sequence[int] | None = None) → torch.Tensor [source]
Reparameterized sampling using masked logits.
- sample(sample_shape: torch.Size | Sequence[int] | None = None) → torch.Tensor [source]
Sample from the distribution using masked logits.
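A sample_shape can be passed as well; a sketch with the [B, T] setup from the examples above, assuming the usual torch.distributions convention that sample_shape is prepended to the batch shape:
>>> dist.sample()                   # [B, T]
>>> dist.sample(torch.Size([4]))    # [4, B, T], every draw respecting the mask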