Source code for torchrl.modules.llm.policies.common

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import torch
from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase, TensorDictSequential
from torch import distributions as D
from torch.distributions import Categorical
from torchrl.modules import MaskedCategorical


class CategoricalSequential(TensorDictModuleBase):
    """A ProbabilisticTensorDictSequential subclass meant to work with LLMs.

    .. seealso:: :class:`~tensordict.nn.ProbabilisticTensorDictSequential` class.
    """

    generate: bool

    def get_dist(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        as_padded_tensor: bool | None = None,
        as_nested_tensor: bool | None = None,
        padding_value: float | None = None,
        padding_side: str = "right",
        layout: torch.layout | None = None,
        **kwargs,
    ) -> D.Distribution:
        td_out = self(tensordict.copy())
        # By default, pad and use masked categorical
        if as_padded_tensor is None:
            as_padded_tensor = as_nested_tensor is not True
            if padding_value is None:
                padding_value = 0.0
        if as_nested_tensor is None:
            as_nested_tensor = False
        logits = td_out.get(
            "logits",
            as_padded_tensor=as_padded_tensor,
            as_nested_tensor=as_nested_tensor,
            padding_value=padding_value,
            padding_side=padding_side,
            layout=layout,
        )
        if as_padded_tensor:
            # We can use MaskedCategorical
            dist = MaskedCategorical(
                logits=logits,
                mask=logits != padding_value,
                use_cross_entropy=True,
            )
            return dist
        # Pass logits as a keyword so they are not interpreted as probabilities
        return Categorical(logits=logits)

    # Sampling is taken care of by the sub-modules
    forward = TensorDictSequential.forward

    @property
    def log_prob_keys(self) -> list[NestedKey]:
        return getattr(self, "_log_prob_keys", ["log_probs"])

    @log_prob_keys.setter
    def log_prob_keys(self, value: list[NestedKey]):
        self._log_prob_keys = value

    @property
    def log_prob_key(self) -> NestedKey:
        return self.log_prob_keys[0]

    @log_prob_key.setter
    def log_prob_key(self, value: NestedKey) -> None:
        self.log_prob_keys[0] = value

    @property
    def dist_params_keys(self) -> list[NestedKey]:
        raise NotImplementedError

    @property
    def dist_sample_keys(self) -> list[NestedKey]:
        return ["tokens_response"]

    def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
        if not self.generate:
            data = self(data)
            return data.get(self.log_prob_key, **get_kwargs)
        raise RuntimeError("log_prob not callable when generate=True.")
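
Example (not part of the source above). The following is a minimal, self-contained sketch of what the default padded branch of get_dist produces once the "logits" entry has been padded: positions whose logits equal padding_value are masked out through MaskedCategorical, mirroring the mask=logits != padding_value and use_cross_entropy=True call above. The batch size, sequence lengths and vocabulary size are illustrative assumptions, not values used by the module.

import torch

from torchrl.modules import MaskedCategorical

# Illustrative shapes: 2 responses right-padded to length 5 over a 4-token vocabulary.
padding_value = 0.0
logits = torch.randn(2, 5, 4)
logits[0, 3:] = padding_value  # first response has 3 valid steps
logits[1, 4:] = padding_value  # second response has 4 valid steps

# Same mask construction as in get_dist: entries equal to the padding value
# are treated as invalid categories.
dist = MaskedCategorical(
    logits=logits,
    mask=logits != padding_value,
    use_cross_entropy=True,
)

# Dummy token ids standing in for a sampled response; log-probabilities at
# fully padded positions are not meaningful and are assumed to be discarded
# downstream (an assumption, not shown in the source above).
tokens = torch.zeros(2, 5, dtype=torch.long)
log_probs = dist.log_prob(tokens)
print(log_probs.shape)  # torch.Size([2, 5])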
