Source code for torchrl.modules.llm.policies.common

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import weakref
from typing import Any, Literal, overload

import torch
from tensordict import NestedKey, TensorDictBase
from tensordict.nn import TensorDictModuleBase, TensorDictSequential
from tensordict.tensorclass import TensorClass
from tensordict.utils import _zip_strict
from torch import distributions as D
from torch.distributions import Categorical
from torch.nn.utils.rnn import pad_sequence
from torchrl.data.llm import History
from torchrl.data.tensor_specs import Unbounded
from torchrl.modules.distributions.discrete import LLMMaskedCategorical

# TODOs:
# - [ ] Remove the useless view(-1) calls when num_samples is not > 1
# - [ ] Remove as_list=True and use a context manager to handle that
# - [ ] Make sure tensordict can handle nested lazy tds that have a get(key, as_list=True) - I think it breaks atm
# - [ ] Handle packing


class Tokens(TensorClass["nocast"]):
    """A Tokens container.

    Args:
        prompt (torch.Tensor | None): The prompt tokens.
        response (torch.Tensor | None): The response tokens.
        full (torch.Tensor | None): The tokens across prompt and response.
        padded (bool | None): Whether the tokens are padded.

    Shapes:
        - prompt: (batch_size, prompt_length). If padded, padded on the left.
        - response: (batch_size, response_length). If padded, padded on the right.
        - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
        - padded: bool.
    """

    prompt: torch.Tensor | None = None
    response: torch.Tensor | None = None
    full: torch.Tensor | None = None
    padded: bool | None = None

    @classmethod
    def default_spec(
        cls,
        shape=(-1,),
        keys: list[Literal["prompt", "response", "full"]] | None = None,
    ):
        """A default spec to use in transforms / envs that return Tokens objects."""
        from torchrl.data import Composite, NonTensor

        if keys is None:
            keys = ["prompt", "response", "full"]
        defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
        defaults["padded"] = NonTensor(shape=shape, example_data=False)
        return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
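# Usage sketch (illustrative, not part of the module): building a padded Tokens batch
# and a matching spec. The tensor contents below are made up for the example.
#
#   >>> prompt = torch.randint(0, 1000, (2, 5))      # left-padded prompt ids
#   >>> response = torch.randint(0, 1000, (2, 3))    # right-padded response ids
#   >>> toks = Tokens(prompt=prompt, response=response, padded=True, batch_size=(2,))
#   >>> toks.prompt.shape
#   torch.Size([2, 5])
#   >>> spec = Tokens.default_spec(shape=(2,))       # Composite spec with prompt/response/full + padded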

class Masks(TensorClass["nocast"]):
    """A Masks container.

    Args:
        all_attention_mask (torch.Tensor | None): The attention mask across all tokens.
            The attention mask represents the tokens that are not masked and that the model can attend to.
        all_assistant_mask (torch.Tensor | None): The assistant mask across all tokens, i.e. the tokens that
            are produced by the assistant. This is recovered from the `assistant_masks` output of
            :meth:`~torchrl.data.llm.History.apply_chat_template`, if the chat template supports it.
        padded (bool | None): Whether the masks are padded.

    The masks always have the same shape as the `full` tensor in
    :class:`~torchrl.modules.llm.policies.common.Tokens` and
    :class:`~torchrl.modules.llm.policies.common.LogProbs`.
    """

    all_attention_mask: torch.Tensor | None = None
    all_assistant_mask: torch.Tensor | None = None
    padded: bool | None = None

    @classmethod
    def default_spec(
        cls,
        shape=(-1,),
        keys: list[Literal["all_attention_mask", "all_assistant_mask"]] | None = None,
    ):
        """A default spec to use in transforms / envs that return Masks objects."""
        from torchrl.data import Composite, NonTensor

        if keys is None:
            keys = ["all_attention_mask", "all_assistant_mask"]
        defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
        defaults["padded"] = NonTensor(shape=shape, example_data=False)
        return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)

class ChatHistory(TensorClass["nocast"]):
    """A chat history container for managing conversation data in LLM environments.

    This class serves as a structured container for chat history data, similar to how
    :class:`~torchrl.modules.llm.policies.Text` and :class:`~torchrl.modules.llm.policies.Tokens`
    are used for text and token data respectively.

    **Recent Changes:**

    - **Modular Design**: ChatHistory is now used consistently across LLM wrappers and environments
      to represent conversation state in a structured way.
    - **Integration with Wrappers**: Both vLLMWrapper and TransformersWrapper now use ChatHistory
      objects when `input_mode="history"` is specified.
    - **Environment Support**: ChatEnv and related environments use ChatHistory for state management.

    Args:
        prompt (History | None): The prompt history stack containing the conversation up to the current point.
        response (History | None): The response history items (typically generated by the LLM).
        full (History | None): The complete history across prompt and response.

    Example:
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>>
        >>> # Create a conversation history
        >>> history = History.from_chats([[
        ...     {"role": "user", "content": "Hello"},
        ...     {"role": "assistant", "content": "Hi there!"}
        ... ]])
        >>>
        >>> # Create ChatHistory object for LLM wrapper input
        >>> chat_history = ChatHistory(prompt=history)
        >>>
        >>> # Use with LLM wrapper
        >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
        >>> print(result["history"].response)  # New response from LLM
        >>> print(result["history"].full)  # Complete conversation

    .. seealso::
        :class:`~torchrl.modules.llm.policies.Text`: Container for text data.
        :class:`~torchrl.modules.llm.policies.Tokens`: Container for token data.
        :class:`~torchrl.data.llm.History`: The underlying History class for conversation data.
    """

    prompt: History | None = None
    response: History | None = None
    full: History | None = None

    @classmethod
    def default_spec(
        cls,
        shape=(-1,),
        keys: list[Literal["prompt", "response", "full"]] | None = None,
    ):
        """A default spec to use in transforms / envs that return ChatHistory objects."""
        from torchrl.data import Composite

        if keys is None:
            keys = ["prompt", "response", "full"]
        return Composite(
            {k: History.default_spec(shape=shape + (-1,)) for k in keys},
            shape=shape[:-1],
            data_cls=cls,
            step_mdp_static=True,
        )
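# Usage sketch (illustrative, not part of the module): a ChatHistory spec as it might appear
# in an environment's observation spec. The batch shape (1,) and the key subset are assumptions
# made only for this example.
#
#   >>> spec = ChatHistory.default_spec(shape=(1,), keys=["prompt", "full"])
#   >>> # `spec` is a Composite with data_cls=ChatHistory and nested History specs under
#   >>> # the "prompt" and "full" entries.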

class LogProbs(TensorClass["nocast"]):
    """A log-probability container.

    Args:
        prompt (torch.Tensor | None): The prompt log-probabilities.
        response (torch.Tensor | None): The response log-probabilities.
        full (torch.Tensor | None): The log-probabilities across prompt and response.
        padded (bool | None): Whether the log-probabilities are padded.

    Shapes:
        - prompt: (batch_size, prompt_length). If padded, padded on the left.
        - response: (batch_size, response_length). If padded, padded on the right.
        - full: (batch_size, prompt_length + response_length). If padded, padded on the left and/or right.
        - padded: bool.
    """

    prompt: torch.Tensor | None = None
    response: torch.Tensor | None = None
    full: torch.Tensor | None = None
    padded: bool | None = None

    @classmethod
    def default_spec(
        cls,
        shape=(-1,),
        keys: list[Literal["prompt", "response", "full"]] | None = None,
    ):
        """A default spec to use in transforms / envs that return LogProbs objects."""
        from torchrl.data import Composite, NonTensor

        if keys is None:
            keys = ["prompt", "response", "full"]
        defaults = {k: Unbounded(shape=shape + (-1,)) for k in keys}
        defaults["padded"] = NonTensor(shape=shape, example_data=False)
        return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)

class Text(TensorClass["nocast"]):
    """A text container.

    Args:
        prompt (str | None): The prompt text.
        response (str | None): The response text.
        full (str | None): The text across prompt and response.
    """

    prompt: str | None = None
    response: str | None = None
    full: str | None = None

    @classmethod
    def default_spec(
        cls,
        shape=(-1,),
        keys: list[Literal["prompt", "response", "full"]] | None = None,
    ):
        """A default spec to use in transforms / envs that return Text objects."""
        from torchrl.data import Composite, NonTensor

        if keys is None:
            keys = ["prompt", "response", "full"]
        defaults = {k: NonTensor(shape=shape, example_data="a string") for k in keys}
        return Composite(defaults, shape=shape[:-1], data_cls=cls, step_mdp_static=True)
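# Usage sketch (illustrative, not part of the module): a non-batched Text carrying a
# prompt/response pair; strings are kept as non-tensor data ("nocast" tensorclass). The
# strings and the shape argument are assumptions made only for this example.
#
#   >>> txt = Text(prompt="What is 2 + 2?", response="4", batch_size=())
#   >>> spec = Text.default_spec(shape=(1,))   # Composite of NonTensor specs with data_cls=Text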

class LogProbDistribution(D.Distribution):
    """A distribution that works directly with log-probabilities.

    This is useful when we have pre-computed log-probabilities (e.g., from vLLM)
    and want to compute log_prob() without having access to the original logits.
    """

    def __init__(self, log_probs: torch.Tensor, mask: torch.Tensor | None = None):
        """Initialize with log-probabilities.

        Args:
            log_probs: Tensor of shape [batch, seq_len] containing log-probabilities.
            mask: Optional mask of shape [batch, seq_len] indicating valid positions.
        """
        self.log_probs = log_probs
        self.mask = mask
        batch_shape = log_probs.shape[:-1] if log_probs.dim() > 1 else log_probs.shape
        event_shape = log_probs.shape[-1:] if log_probs.dim() > 1 else torch.Size([])
        super().__init__(batch_shape=batch_shape, event_shape=event_shape)

    def log_prob(self, value: torch.Tensor) -> torch.Tensor:
        """Compute log-probability for the given tokens.

        Args:
            value: Tensor of shape [batch, seq_len] containing token indices.

        Returns:
            Tensor of shape [batch, seq_len] containing log-probabilities.
        """
        # For log-prob distributions, we simply return the pre-computed log-probs;
        # the value tensor is only used to validate the expected shape.
        if value.shape != self.log_probs.shape:
            raise ValueError(
                f"Value shape {value.shape} must match log_probs shape {self.log_probs.shape}"
            )
        result = self.log_probs.clone()
        # Apply mask if provided
        if self.mask is not None:
            result = torch.where(
                self.mask,
                result,
                torch.tensor(0.0, device=result.device, dtype=result.dtype),
            )
        return result

    def sample(self, sample_shape: tuple | torch.Size | None = None) -> torch.Tensor:
        """Sample from the distribution.

        Note: This is not implemented for log-prob distributions since we don't have
        the full probability distribution, only the log-probs for specific tokens.
        """
        raise NotImplementedError("Sampling not supported for LogProbDistribution")

    def entropy(self) -> torch.Tensor:
        """Compute entropy.

        Note: This is not implemented for log-prob distributions since we don't have
        the full probability distribution.
        """
        raise NotImplementedError("Entropy not supported for LogProbDistribution")
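# Usage sketch (illustrative, not part of the module): wrapping pre-computed per-token
# log-probabilities (e.g. returned by an inference engine) so that downstream code can
# call .log_prob() on them. The values below are made up for the example.
#
#   >>> lp = torch.tensor([[-0.1, -2.3, -0.5, -1.2]])        # [batch=1, seq_len=4]
#   >>> mask = torch.tensor([[True, True, True, False]])     # last position is padding
#   >>> dist = LogProbDistribution(lp, mask=mask)
#   >>> tokens = torch.tensor([[11, 42, 7, 0]])               # same shape as lp
#   >>> dist.log_prob(tokens)                                  # masked position zeroed out
#   tensor([[-0.1000, -2.3000, -0.5000,  0.0000]])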

class LLMWrapperBase(TensorDictModuleBase):
    r"""An LLM wrapper base class.

    This class provides a consistent interface for LLM wrappers with the following features:

    - Support for different input modalities (history, text, tokens)
    - Consistent output structure using TensorClass objects (Text, Tokens, Masks, LogProbs)
    - Configurable generation and log-probability computation

    Args:
        model: The underlying model to wrap.

    Keyword Args:
        tokenizer: The tokenizer to use for encoding and decoding text.
        input_mode: The input modality to use. Must be one of "history", "text", or "tokens".
        input_key: The key for the input data. If None, defaults to the input_mode name.
        attention_mask_key: The key for attention masks (used in "tokens" mode).
        generate: Whether to enable text generation.
        generate_kwargs: Additional arguments to pass to the model's generate method.
        tokenizer_kwargs: Additional arguments to pass to the tokenizer.
        pad_output: Whether to pad the output sequences to a uniform length.
        inplace: Determines how the module should handle in-place operations.
        device: The device to use for computation.
        layout: The layout to use for the output tensors when pad_output=False.
        num_samples: The number of samples to generate.
        log_probs_key (NestedKey | None, optional): The key for the log probabilities
            :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
        text_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
        tokens_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
        masks_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.

    Attributes:
        collector: The collector associated with the module, if it exists.

    .. seealso::
        - :class:`~torchrl.modules.llm.policies.TransformersWrapper` (see :ref:`ref_transformers_wrapper`)
        - :class:`~torchrl.modules.llm.policies.vLLMWrapper` (see :ref:`ref_vllm_wrapper`)
    """

    generate: bool
    pad_output: bool
    text_key: NestedKey
    tokens_key: NestedKey
    masks_key: NestedKey
    log_probs_key: NestedKey
    in_keys: list[NestedKey]
    out_keys: list[NestedKey]
    inplace: bool
    device: torch.device | None
    layout: torch.layout | None
    num_samples: int | None

    @overload
    def __init__(
        self,
        model: Any | str,
        *,
        tokenizer: callable | str | None = None,  # type: ignore
        input_mode: str = "history",
        input_key: NestedKey | None = None,
        attention_mask_key: str = "attention_mask",
        generate: bool = True,
        generate_kwargs: dict | None = None,
        tokenizer_kwargs: dict | None = None,
        pad_output: bool = False,
        inplace: Literal[True, False, "empty"] | None = None,
        device: torch.device | None = None,
        layout: torch.layout | None = None,
        num_samples: int | None = None,
        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
        chat_template: str | None = None,
        return_log_probs: bool | None = None,
        history_key: NestedKey | None = "history",
        text_key: NestedKey | None = "text",
        tokens_key: NestedKey | None = "tokens",
        masks_key: NestedKey | None = "masks",
        log_probs_key: NestedKey | None = "log_probs",
    ):
        ...

    def __init__(self, *args, **kwargs):
        super().__init__()

    def get_new_version(self, **kwargs):
        """Returns a new version of the module with altered parameters.

        For instance, the generate parameter can be altered to enable text generation or
        log-probabilities computation.

        This is especially useful when one wants to avoid re-initializing the module with a new
        set of parameters, when the same parameters could be used to gather log-probs.

        Positional arguments are not supported.

        See the class constructor for more details about the parameters.
        """
        raise NotImplementedError

    _collector: weakref.ReferenceType[
        LLMCollector  # noqa: F821 # type: ignore
    ] | None = None

    def register_collector(self, collector: LLMCollector):  # noqa: F821 # type: ignore
        """Registers a weak reference to the container collector.

        This is automatically called by the :class:`~torchrl.collectors.llm.LLMCollector` class.
        """
        self._collector = weakref.ref(collector)

    @property
    def collector(self) -> LLMCollector | None:  # noqa: F821 # type: ignore
        """Returns the collector associated with the module, if it exists."""
        return self._collector() if self._collector is not None else None

    def get_dist(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        logits_key: NestedKey = "logits",
        mask_key: NestedKey | None = None,
        as_padded_tensor: bool | None = None,
        as_nested_tensor: bool | None = None,
        padding_value: float | None = None,
        padding_side: str = "left",
        layout: torch.layout | None = None,
        **kwargs,
    ) -> D.Distribution:
        """Get distribution from logits/log-probs with optional masking.

        Args:
            tensordict: Input tensordict.
            tensordict_out: Output tensordict (optional).
            logits_key: Key for logits/log-probs.
            mask_key: Key for mask (optional).
            as_padded_tensor: Whether to return padded tensor. Default is False.
            as_nested_tensor: Whether to return nested tensor. Default is False.
            padding_value: Value for padding. Default is 0.0 for logits and False for masks.
            padding_side: Side for padding. Default is left by convention.
            layout: Tensor layout.
            **kwargs: Additional arguments.

        Returns:
            Distribution (Categorical or LLMMaskedCategorical)
        """
        if self.generate:
            raise NotImplementedError(
                "get_dist is not implemented for generate=True. "
                "You can create a new version of this wrapper using the `get_new_version` method."
            )
        td_out = self(tensordict.copy())

        # Get logits/log-probs
        if as_padded_tensor is None:
            as_padded_tensor = as_nested_tensor is not True
            if padding_value is None:
                padding_value = 0.0
        if as_nested_tensor is None:
            as_nested_tensor = False
        logits = td_out.get(
            logits_key,
            as_padded_tensor=as_padded_tensor,
            as_nested_tensor=as_nested_tensor,
            padding_value=padding_value,
            padding_side=padding_side,
            layout=layout,
        )

        # Get mask if provided
        mask = None
        if mask_key is not None:
            mask = td_out.get(
                mask_key,
                as_padded_tensor=as_padded_tensor,
                as_nested_tensor=as_nested_tensor,
                padding_value=False,
                padding_side=padding_side,
                layout=layout,
            )
        elif as_padded_tensor:
            # Default mask for padded tensors
            mask = logits != padding_value

        if mask is not None:
            dist = LLMMaskedCategorical(
                logits=logits,
                mask=mask,
            )
            if not dist._position_level_masking:
                raise ValueError(
                    "Expected a position-level mask (one boolean per token position) "
                    "matching the token dimension of the logits."
                )
            return dist
        return Categorical(logits)

    def _get_dist_with_prompt_mask(
        self,
        tensordict: TensorDictBase,
        tokens_key: NestedKey = ("tokens", "prompt"),
        logits_key: NestedKey = "logits",
        # TODO: add a prompt_mask and response_mask in Masks
        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
        padding_side: str = "left",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked to only include response tokens (exclude prompt).

        This is suitable for single-turn scenarios where we want to compute loss only on the
        generated response, not the input prompt.

        Note: If prompt tokens are not available (e.g., when using history input), this method
        falls back to using the assistant mask.

        Padding side is left by convention.

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        if self.generate:
            raise NotImplementedError(
                "get_dist_with_prompt_mask is not implemented for generate=True. "
                "You can create a new version of this wrapper using the `get_new_version` method."
            )
        td_out = self(tensordict.copy())

        # Try to get prompt tokens first
        if self.pad_output:
            prompt_tokens = tensordict.get(
                tokens_key,
                as_padded_tensor=True,
                padding_value=-100,
                padding_side=padding_side,
            )
            logits = td_out.get(
                logits_key,
                as_padded_tensor=True,
                padding_value=0.0,
                padding_side=padding_side,
            )
            attention_mask = tensordict.get(
                attention_mask_key,
                as_padded_tensor=True,
                padding_value=False,
                padding_side=padding_side,
            )
            assistant_mask = tensordict.get(
                assistant_mask_key,
                as_padded_tensor=True,
                padding_value=False,
                padding_side=padding_side,
            )
        else:
            prompt_tokens = tensordict.get(tokens_key, as_list=True)
            logits = td_out.get(logits_key, as_list=True)
            attention_mask = td_out.get(attention_mask_key, as_list=True)
            assistant_mask = td_out.get(assistant_mask_key, as_list=True)

        if prompt_tokens is None:
            if assistant_mask is None:
                raise ValueError(
                    f"Assistant mask not found in tensordict at key {assistant_mask_key} (keys: {td_out.keys()})"
                )
            if self.pad_output:
                response_mask = assistant_mask.clone()
            else:
                response_mask = [am.clone() for am in assistant_mask]
        else:
            if self.pad_output:
                response_mask = attention_mask.clone()
                response_mask[..., : prompt_tokens.shape[-1]] = False
            else:
                response_mask = []
                for am, p in _zip_strict(attention_mask, prompt_tokens):
                    am = am.clone()
                    am[..., : p.size(-1)] = False
                    response_mask.append(am)

        if logits is None:
            raise ValueError(
                f"Logits not found in tensordict at key {logits_key} (keys: {td_out.keys()})"
            )

        # Make the response mask using prompt tokens
        if not self.pad_output:
            # Check that the lengths of the mask are the same as the logits
            for m, lg in _zip_strict(response_mask, logits):
                if m.shape[-1] != lg.shape[-2]:
                    raise ValueError(
                        f"Mask and logits have different lengths: {m.shape[-1]} != {lg.shape[-2]}.\n"
                        f"All the logits shapes: {[lg.shape for lg in logits]}, all the mask shapes: {[m.shape for m in response_mask]}"
                    )
            logits = pad_sequence(
                logits, batch_first=True, padding_value=0.0, padding_side=padding_side
            )
            response_mask = pad_sequence(
                response_mask,
                batch_first=True,
                padding_value=False,
                padding_side=padding_side,
            )

        dist = LLMMaskedCategorical(
            logits=logits,
            mask=response_mask.bool(),
        )
        if not dist._position_level_masking:
            raise ValueError(
                "Expected a position-level response mask (one boolean per token position) "
                "matching the token dimension of the logits."
            )
        return dist

    def _get_dist_with_assistant_mask(
        self,
        tensordict: TensorDictBase,
        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
        logits_key: NestedKey = "logits",
        padding_side: str = "left",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked to only include assistant tokens.

        This is suitable for multi-turn scenarios where we want to compute loss only on
        assistant-generated tokens across the entire conversation.

        Padding side is left by convention.

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        if self.generate:
            raise NotImplementedError(
                "get_dist_with_assistant_mask is not implemented for generate=True. "
                "You can create a new version of this wrapper using the `get_new_version` method."
            )
        td_out = self(tensordict.copy())
        # Update the tokens key to reflect the tokenized history when querying the log-probs
        tensordict.update(
            td_out,
            keys_to_update=[
                ("tokens", "full"),
            ],
        )
        if self.pad_output:
            logits = td_out.get(logits_key)
            assistant_mask = td_out.get(assistant_mask_key)
        else:
            logits = td_out.get(
                logits_key,
                as_padded_tensor=True,
                padding_value=0.0,
                padding_side=padding_side,
            )
            assistant_mask = td_out.get(
                assistant_mask_key,
                as_padded_tensor=True,
                padding_value=False,
                padding_side=padding_side,
            )
        if logits is None:
            raise ValueError(f"Logits not found in tensordict at key {logits_key}")
        if assistant_mask is None:
            if self.input_mode != "history":
                post_msg = "This is likely because the input_mode is not 'history'."
            else:
                post_msg = ""
            raise ValueError(
                f"Assistant mask not found in tensordict at key {assistant_mask_key}. {post_msg}"
            )
        dist = LLMMaskedCategorical(
            logits=logits,
            mask=assistant_mask,
        )
        if not dist._position_level_masking:
            raise ValueError(
                "Expected a position-level assistant mask (one boolean per token position) "
                "matching the token dimension of the logits."
            )
        return dist

    def _get_dist_with_attention_mask(
        self,
        tensordict: TensorDictBase,
        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
        logits_key: NestedKey = "logits",
        padding_side: str = "left",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked using attention mask.

        This is suitable for generic scenarios where we want to compute loss on all valid tokens
        (non-padding tokens).

        Padding side is left by convention.

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        if self.generate:
            raise NotImplementedError(
                "get_dist_with_attention_mask is not implemented for generate=True. "
                "You can create a new version of this wrapper using the `get_new_version` method."
            )
        td_out = self(tensordict.copy())
        if self.pad_output:
            logits = td_out.get(logits_key)
            attention_mask = td_out.get(attention_mask_key)
        else:
            logits = td_out.get(
                logits_key,
                as_padded_tensor=True,
                padding_value=0.0,
                padding_side=padding_side,
            )
            attention_mask = td_out.get(
                attention_mask_key,
                as_padded_tensor=True,
                padding_value=False,
                padding_side=padding_side,
            )
        if logits is None:
            raise ValueError(f"Logits not found in tensordict at key {logits_key}")
        if attention_mask is None:
            raise ValueError(
                f"Attention mask not found in tensordict at key {attention_mask_key}"
            )
        dist = LLMMaskedCategorical(
            logits=logits,
            mask=attention_mask,
        )
        if not dist._position_level_masking:
            raise ValueError(
                "Expected a position-level attention mask (one boolean per token position) "
                "matching the token dimension of the logits."
            )
        return dist

    def _get_dist_with_custom_mask(
        self,
        tensordict: TensorDictBase,
        mask: torch.Tensor,
        logits_key: NestedKey = "logits",
        padding_side: str = "left",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution with custom mask.

        This allows for completely custom masking logic.

        Padding side is left by convention.

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        if self.generate:
            raise NotImplementedError(
                "get_dist_with_custom_mask is not implemented for generate=True. "
                "You can create a new version of this wrapper using the `get_new_version` method."
            )
        td_out = self(tensordict.copy())
        if self.pad_output:
            logits = td_out.get(logits_key)
        else:
            logits = td_out.get(
                logits_key,
                as_padded_tensor=True,
                padding_value=0.0,
                padding_side=padding_side,
            )
        if logits is None:
            raise ValueError(f"Logits not found in tensordict at key {logits_key}")
        dist = LLMMaskedCategorical(
            logits=logits,
            mask=mask,
        )
        if not dist._position_level_masking:
            raise ValueError(
                "Expected a position-level custom mask (one boolean per token position) "
                "matching the token dimension of the logits."
            )
        return dist

    # Convenience methods for common LLM training scenarios
    def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for SFT loss (response tokens only).

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        return self._get_dist_with_prompt_mask(tensordict, **kwargs)

    def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for RLHF loss (assistant tokens only).

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        return self._get_dist_with_assistant_mask(tensordict, **kwargs)

    def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for generic losses (all tokens).

        This is a provisional method that will be replaced by the `get_dist` method once we have
        a better masking strategy.
        """
        return self._get_dist_with_attention_mask(tensordict, **kwargs)

    # Sampling is taken care of by the sub-modules
    forward = TensorDictSequential.forward

    def _check_padded(self, val: torch.Tensor) -> torch.Tensor:
        """Check that a value is a padded tensor."""
        if not isinstance(val, torch.Tensor):
            raise ValueError(f"Expected a padded tensor, got {type(val)}")
        return val

    def _check_not_padded(
        self, val: list[torch.Tensor] | torch.Tensor
    ) -> list[torch.Tensor] | torch.Tensor:
        """Check that a value is not a padded tensor (i.e., a list of tensors)."""
        if isinstance(val, torch.Tensor):
            raise ValueError("Expected a list of tensors - not padded, got a tensor")
        return val

    @property
    def log_prob_keys(self) -> list[NestedKey]:
        return getattr(self, "_log_prob_keys", ["log_probs"])

    @log_prob_keys.setter
    def log_prob_keys(self, value: list[NestedKey]):
        self._log_prob_keys = value

    @property
    def dist_params_keys(self) -> list[NestedKey]:
        raise NotImplementedError

    @property
    def dist_sample_keys(self) -> list[NestedKey]:
        return ["tokens_response"]

    def log_prob(self, data: TensorDictBase, **get_kwargs) -> TensorDictBase:
        if not self.generate:
            data = self(data)
            return data.get((self.log_prob_key, "response"), **get_kwargs)
        raise RuntimeError("log_prob not callable when generate=True.")
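# Usage sketch (illustrative, not part of the module): the typical log-prob workflow with a
# concrete subclass such as TransformersWrapper (assumed here; see its own documentation for
# the exact constructor arguments). Because `get_dist` raises when generate=True, the
# generating wrapper is cloned into a log-prob variant with `get_new_version`.
#
#   >>> gen_wrapper = TransformersWrapper(model, tokenizer=tokenizer, generate=True)  # assumed subclass
#   >>> rollout = gen_wrapper(td)                                                      # generation pass
#   >>> lp_wrapper = gen_wrapper.get_new_version(generate=False, return_log_probs=True)
#   >>> dist = lp_wrapper._get_rlhf_dist(rollout)   # LLMMaskedCategorical over assistant tokens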

def _extract_responses_from_full_histories(
    text_full: list[str],
    prompt_histories,
    chat_template_name: str | None = None,
    tokenizer=None,
) -> History:
    """Extract response histories from full text histories.

    This function parses the full text back to history objects and extracts the response
    portions (everything after the prompt).

    Args:
        text_full: List of full text strings to parse.
        prompt_histories: The original prompt histories.
        chat_template_name: Optional chat template name for parsing.
        tokenizer: Optional tokenizer for template detection.

    Returns:
        Stacked History object with response portions.

    Raises:
        RuntimeError: If full history is shorter than prompt history.
        RuntimeError: If parsing produces inconsistent batch shapes.
    """
    import torch
    from tensordict.utils import _zip_strict
    from torchrl.data.llm import History

    # Extract response portions by processing each element individually.
    # This avoids the stacking issue when different batch elements produce
    # different numbers of responses.
    response_histories = []
    full_histories = History.from_text(
        text_full, chat_template_name=chat_template_name, tokenizer=tokenizer
    )
    for h_prompt, h_full in _zip_strict(
        prompt_histories.unbind(0), full_histories.unbind(0)
    ):
        if h_full.shape[0] <= h_prompt.shape[0]:
            raise RuntimeError(
                f"Full history is shorter than prompt history: {h_full.shape} <= {h_prompt.shape}"
            )
        # Note: there can be more than one response, so the response has the same number of dims as the prompt
        response_histories.append(h_full[h_prompt.shape[0] :])

    # Check if all responses have the same shape
    shapes = [r.shape for r in response_histories]
    if len(set(shapes)) > 1:
        # Different shapes detected - pad to the same length
        max_length = max(r.shape[0] for r in response_histories)
        padded_responses = []
        for response in response_histories:
            if response.shape[0] < max_length:
                # Pad with empty messages using the "<none>" role
                padding_needed = max_length - response.shape[0]
                padding_history = History(
                    role="<none>", content="", batch_size=(padding_needed,)
                )
                padded_response = response.extend(padding_history, inplace=False)
                padded_responses.append(padded_response)
            else:
                padded_responses.append(response)
        return torch.stack(padded_responses)
    return torch.stack(response_histories)
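# Usage sketch (illustrative, not part of the module): the slicing performed above, shown on a
# toy conversation. The chat contents are made up for the example; the full history is re-parsed
# from text and the prompt turns are dropped, leaving only the newly generated turns.
#
#   >>> prompt = History.from_chats([[{"role": "user", "content": "Hi"}]])[0]   # 1 turn
#   >>> full = History.from_chats([[
#   ...     {"role": "user", "content": "Hi"},
#   ...     {"role": "assistant", "content": "Hello!"},
#   ... ]])[0]                                                                   # 2 turns
#   >>> response = full[prompt.shape[0]:]                                        # assistant turn(s) only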
