
LLMMaskedCategorical

class torchrl.modules.LLMMaskedCategorical(logits: Tensor, mask: Tensor, ignore_index: int = -100)[source]

LLM-optimized masked categorical distribution.

This class provides a more memory-efficient approach for LLM training by:

  1. Using ignore_index=-100 for log_prob computation (no masking overhead)

  2. Using traditional masking for sampling operations

This is particularly beneficial for large vocabulary sizes where masking all logits can be memory-intensive.

Parameters:
  • logits (torch.Tensor) – Event log probabilities (unnormalized), shape [*B, T, C].
      - B: batch size (optional)
      - T: sequence length
      - C: vocabulary size (number of classes)

  • mask (torch.Tensor) –

    Boolean mask indicating valid positions/tokens.
      - If shape [*B, T]: position-level masking. True means the position is valid (all tokens are allowed there).
      - If shape [*B, T, C]: token-level masking. True means the token is valid at that position.

    Warning

    Token-level masking is considerably more memory-intensive than position-level masking; only use it if you genuinely need per-token masks (see the memory sketch after this parameter list).

  • ignore_index (int, optional) – Index to ignore in log_prob computation. Defaults to -100.
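
As a rough illustration of the warning above, this back-of-the-envelope comparison counts only the boolean mask entries themselves (actual peak memory depends on the implementation):

    >>> B, T, C = 8, 1024, 50000
    >>> B * T * C   # token-level bool mask: ~4.1e8 entries (~410 MB at 1 byte each)
    409600000
    >>> B * T       # position-level bool mask: 8192 entries (~8 KB)
    8192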

Input shapes:
  • logits: [*B, T, C] (required)

  • mask: [*B, T] (position-level) or [*B, T, C] (token-level)

  • tokens (for log_prob): [*B, T] (token indices, with ignore_index for masked positions)

Use cases:
  1. Position-level masking
    >>> logits = torch.randn(2, 10, 50000)  # [B=2, T=10, C=50000]
    >>> mask = torch.ones(2, 10, dtype=torch.bool)  # [B, T]
    >>> mask[0, :5] = False  # mask first 5 positions of first sequence
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))  # [B, T]
    >>> tokens[0, :5] = -100  # set masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    
  2. Token-level masking
    >>> logits = torch.randn(2, 10, 50000)
    >>> mask = torch.ones(2, 10, 50000, dtype=torch.bool)  # [B, T, C]
    >>> mask[0, :5, :1000] = False  # mask first 1000 tokens for first 5 positions
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> tokens = torch.randint(0, 50000, (2, 10))
    >>> # Optionally, set tokens at fully-masked positions to ignore_index
    >>> log_probs = dist.log_prob(tokens)
    >>> samples = dist.sample()  # [B, T]
    

Notes

  • For log_prob, tokens must be of shape [*B, T] and contain valid token indices (0 <= token < C), or ignore_index at masked/ignored positions.

  • For token-level masking, if a token is masked at a given position, log_prob will return -inf for that entry.

  • For position-level masking, if a position is masked (ignore_index), log_prob will return 0.0 for that entry, so it contributes nothing to a cross-entropy loss, matching the ignore_index convention.

  • Sampling always respects the mask (masked tokens/positions are never sampled).

All documented use cases are covered by tests in test_distributions.py.
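
A doctest-style sketch of the log_prob conventions above (tiny shapes chosen for readability; the expected values in the comments follow the notes):

    >>> logits = torch.randn(1, 2, 4)                   # [B=1, T=2, C=4]
    >>> tok_mask = torch.ones(1, 2, 4, dtype=torch.bool)
    >>> tok_mask[0, 0, 3] = False                       # token 3 invalid at position 0
    >>> dist = LLMMaskedCategorical(logits=logits, mask=tok_mask)
    >>> dist.log_prob(torch.tensor([[3, 1]]))[0, 0]     # masked token -> -inf
    >>> pos_mask = torch.tensor([[False, True]])        # position 0 masked
    >>> dist = LLMMaskedCategorical(logits=logits, mask=pos_mask)
    >>> dist.log_prob(torch.tensor([[-100, 1]]))[0, 0]  # ignored position -> 0.0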

clear_cache()[source]

Clear cached masked tensors to free memory.
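
A minimal usage sketch (assuming the masked tensors are built lazily by the first sampling operation):

    >>> samples = dist.sample()  # builds and caches the masked logits
    >>> dist.clear_cache()       # release the cached masked tensors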

entropy() → Tensor[source]

Compute entropy using masked logits.
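
Illustrative usage (the per-position return shape is an assumption consistent with the input shapes above):

    >>> logits = torch.randn(2, 10, 50000)
    >>> mask = torch.ones(2, 10, dtype=torch.bool)
    >>> dist = LLMMaskedCategorical(logits=logits, mask=mask)
    >>> ent = dist.entropy()  # assumed shape: [B, T] = [2, 10]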

log_prob(value: Tensor) → Tensor[source]

Compute log probabilities using ignore_index approach.

This is memory-efficient as it doesn’t require masking the logits. The value tensor should use ignore_index for masked positions.
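
Conceptually this is the same trick as torch.nn.functional.cross_entropy with ignore_index; a sketch of the idea (not the actual implementation):

    >>> import torch.nn.functional as F
    >>> # cross_entropy returns -log_prob, and ignore_index entries contribute 0
    >>> lp = -F.cross_entropy(
    ...     logits.flatten(0, 1),   # [B*T, C]
    ...     tokens.flatten(),       # [B*T], with -100 at ignored positions
    ...     ignore_index=-100,
    ...     reduction='none',
    ... ).view(tokens.shape)        # back to [B, T]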

property logits: Tensor

Get the original logits.

property mask: Tensor

Get the mask.

property masked_dist: Categorical

Get the masked distribution for sampling operations.

property masked_logits: Tensor

Get the masked logits for sampling operations.

property mode: Tensor

Get the mode using masked logits.

property position_level_masking: bool

Whether the mask is position-level (True) or token-level (False).
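
For example, with tensors shaped as in the use cases above:

    >>> dist = LLMMaskedCategorical(logits=torch.randn(2, 10, 50000),
    ...                             mask=torch.ones(2, 10, dtype=torch.bool))
    >>> dist.position_level_masking  # [B, T] mask -> True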

property probs: Tensor

Get probabilities from original logits.

rsample(sample_shape: torch.Size | Sequence[int] | None = None) → torch.Tensor[source]

Reparameterized sampling using masked logits.

sample(sample_shape: torch.Size | Sequence[int] | None = None) → torch.Tensor[source]

Sample from the distribution using masked logits.
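
Usage sketch (the leading sample dimensions are assumed to follow the standard torch.distributions convention):

    >>> samples = dist.sample()    # [B, T]
    >>> batch = dist.sample((5,))  # assumed shape: [5, B, T]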
