Source code for torchrl.envs.llm.transforms.kl

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings

from contextlib import nullcontext
from copy import copy
from typing import Any, Literal, TYPE_CHECKING

import torch
from tensordict import NestedKey, set_list_to_stack, TensorDictBase, unravel_key
from tensordict.utils import _zip_strict, is_seq_of_nested_key, logger as torchrl_logger
from torch.nn.utils.rnn import pad_sequence
from torchrl.data import Composite, Unbounded
from torchrl.envs import EnvBase, Transform
from torchrl.envs.transforms.transforms import Compose
from torchrl.envs.transforms.utils import _set_missing_tolerance
from torchrl.modules.llm.policies.common import LLMWrapperBase

if TYPE_CHECKING:
    import transformers


class KLRewardTransform(Transform):
    """A legacy transform for computing KL divergence-based rewards.

    **Deprecated**: This transform is maintained for backward compatibility but is no longer the
    recommended approach. Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` instead, which
    provides better modularity and integration with the new wrapper design.

    **Recent Changes:**

    - **Legacy Status**: This transform is now considered legacy and may not work optimally with the
      new modular wrapper design.
    - **ChatHistory Integration**: Limited support for the new
      :class:`~torchrl.modules.llm.policies.ChatHistory` objects.
    - **Input Mode Support**: May not handle all input modes (`"history"`, `"text"`, `"tokens"`) consistently.

    **Recommendation**: Use :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` for new code, which provides:

    - Better integration with the new wrapper design
    - Consistent support for all input modes
    - Proper handling of ChatHistory objects
    - More modular and composable architecture

    Args:
        ref_model (LLMWrapperBase): the reference model, configured to compute log-probs
            (not to generate).

    Keyword Args:
        coef (float): the coefficient of the KL term subtracted from the reward. Defaults to `1.0`.
        in_keys (list of NestedKey): the reward key to read. Defaults to `["reward"]`.
        out_keys (list of NestedKey): the output keys. Defaults to `["reward", "kl_penalty", "ref_log_prob"]`.
        log_prob_key (NestedKey): the key where the full log-probs of the generating model are stored.
            Defaults to `("log_probs", "full")`.
        device (torch.device): the device to use. Defaults to `None`.
        add_to_reward (bool): whether to subtract the KL term from the reward. Defaults to `True`.
        tokenizer (transformers.AutoTokenizer): the tokenizer to use. Defaults to `None`.
        assistant_only (bool): whether to only compute KL on assistant tokens. Defaults to `True`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.

    Examples:
        >>> # Legacy usage (not recommended for new code)
        >>> transform = KLRewardTransform(ref_model)
        >>>
        >>> # Recommended approach using RetrieveKL
        >>> from torchrl.envs.llm.transforms.kl import RetrieveKL
        >>> transform = RetrieveKL(gen_model, ref_model, assistant_only=True)

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: The recommended transform for KL divergence computation.
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: Base transform for retrieving log-probabilities.
        :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: Transform for computing KL divergence between log-prob tensors.
    """

    DEFAULT_IN_KEYS = ["reward"]

    def __init__(
        self,
        ref_model: LLMWrapperBase,
        *,
        coef=1.0,
        in_keys=None,
        out_keys=None,
        log_prob_key: NestedKey = ("log_probs", "full"),
        device: torch.device | None = None,
        add_to_reward: bool = True,
        tokenizer: transformers.AutoTokenizer | None = None,
        assistant_only: bool = True,
        padding_side: str = "left",
    ):
        if in_keys is None:
            in_keys = self.DEFAULT_IN_KEYS
        if out_keys is None:
            out_keys = copy(in_keys)
        if len(out_keys) == len(in_keys):
            out_keys = out_keys + ["kl_penalty", "ref_log_prob"]
        elif len(out_keys) != len(in_keys) + 2:
            raise ValueError(
                "The out_keys must have the same length as the in_keys (plus two additional optional kl entries for logging)."
            )
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        if not is_seq_of_nested_key(self.in_keys) or not is_seq_of_nested_key(
            self.out_keys
        ):
            raise ValueError(
                f"invalid in_keys / out_keys:\nin_keys={self.in_keys} \nout_keys={self.out_keys}"
            )
        if len(self.in_keys) != 1 or len(self.out_keys) != 3:
            raise ValueError(
                f"Only one in_key/out_key is allowed, got in_keys={self.in_keys}, out_keys={self.out_keys}."
            )
        self._out_keys = [unravel_key(out_key) for out_key in self._out_keys]

        if getattr(ref_model, "generate", False):
            raise ValueError(
                "The actor is configured to generate text, not compute the log-probs."
            )

        # update the in_keys for dispatch etc
        self.in_keys = self.in_keys + ref_model.in_keys
        self.in_keys = [unravel_key(in_key) for in_key in self.in_keys]

        self.add_to_reward = add_to_reward
        # check that the model has parameters
        self.__dict__["ref_model"] = ref_model
        # self._buffers["actor_params"] = params.clone().detach()

        self.device = device

        # find the sample log-prob key
        self.log_prob_full_key = log_prob_key
        self._tokenizer = tokenizer
        self.assistant_only = assistant_only
        self.padding_side = padding_side

        if not isinstance(coef, torch.Tensor):
            coef = torch.as_tensor(coef)
        self.register_buffer("coef", coef)

        # sanity check for the ref_model
        if not getattr(ref_model, "input_mode", "tokens") == "tokens":
            raise ValueError(
                "The ref_model must be configured to use tokens as input. Please set the `input_mode` argument to `tokens`."
            )

    @property
    def pad_output(self):
        # We need pad_output to match the pad_output of the inference model
        return self.ref_model.pad_output

    @property
    def tokenizer(self):
        tokenizer = self._tokenizer
        if tokenizer is not None:
            return tokenizer
        try:
            return self.ref_model.tokenizer
        except AttributeError:
            raise AttributeError(
                "The ref_model does not have a tokenizer. Please pass the tokenizer to the constructor."
            )

    def set_container(self, container: Transform | EnvBase) -> None:
        result = super().set_container(container)
        if self.action_key is None:
            parent = getattr(self, "parent", None)
            if parent is not None:
                action_keys = parent.action_keys
                if len(action_keys) != 1:
                    raise ValueError(
                        f"More than one action_key found. Please pass the `action_key` argument directly to {type(self).__name__}."
                    )
                action_key = action_keys[0]
                self.action_key = action_key
        return result

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        with _set_missing_tolerance(self, True):
            tensordict_reset = self._step(tensordict_reset, tensordict_reset)
        return tensordict_reset

    @property
    def action_key(self) -> NestedKey:
        # Get the action from the base env (a ChatEnv).
        if self.parent.base_env.input_mode == "history":
            return ("history", "full")
        if self.parent.base_env.input_mode == "text":
            return ("text", "full")
        if self.parent.base_env.input_mode == "tokens":
            return ("tokens", "full")
        raise ValueError(f"Invalid input mode: {self.parent.base_env.input_mode}")

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        if self.device is not None:
            tensordict = tensordict.to(self.device)
            next_tensordict = next_tensordict.to(self.device)
        # tensordict = self._get_text_response(tensordict, next_tensordict)
        response = tensordict.get(self.action_key, None)
        if response is None:
            if not self.missing_tolerance:
                raise RuntimeError(
                    f"Action with key {self.action_key} not found data {tensordict}"
                )
            # being called after reset or without action, skipping
            if self.out_keys[0] != "reward" and self.parent is not None:
                next_tensordict.set(self.out_keys[0], self.parent.reward_spec.zero())
            return next_tensordict

        # We use the ("tokens", "full") key to get the log-probs of the reference model
        with torch.device(self.device) if self.device is not None else nullcontext():
            td_input = tensordict.copy()
            ref_log_prob_td = self.ref_model(td_input)
            if self.pad_output:
                ref_log_prob_padded = ref_log_prob_td.get(self.log_prob_full_key)
            else:
                ref_log_prob_unpadded = ref_log_prob_td.get(
                    self.log_prob_full_key, as_list=True
                )
            if self.assistant_only:
                # Get the assistant mask
                mask = tensordict.get(("masks", "all_assistant_mask"))
                # mask will often be None - fall back on prompt / response separation
                if mask is None:
                    if self.pad_output:
                        # simple case: just take the prompt length
                        prompt_length = tensordict.get(("tokens", "prompt")).shape[-1]
                        mask = tensordict.get(("masks", "all_attention_mask")).clone()
                        mask[..., :prompt_length] = False
                    else:
                        # simple case: just take the prompt length
                        prompt_length = [
                            t.size(-1)
                            for t in tensordict.get(("tokens", "prompt"), as_list=True)
                        ]
                        mask = tensordict.get(("masks", "all_attention_mask"), as_list=True)
                        for i in range(len(prompt_length)):
                            mask[i] = mask[i].clone()
                            mask[i][..., : prompt_length[i]] = False
                # we want to keep the batch dimension
                ref_log_prob_list = []
                if self.pad_output:
                    for i in range(ref_log_prob_padded.size(0)):
                        ref_log_prob_list.append(
                            ref_log_prob_padded[i].masked_fill(~mask[i], 0)
                        )
                else:
                    for i in range(len(ref_log_prob_unpadded)):
                        ref_log_prob_list.append(
                            ref_log_prob_unpadded[i].masked_fill(~mask[i], 0)
                        )
            if self.pad_output:
                ref_log_prob = pad_sequence(
                    ref_log_prob_list,
                    batch_first=True,
                    padding_value=0,
                    padding_side=self.padding_side,
                )
            else:
                ref_log_prob = torch.nested.nested_tensor(
                    ref_log_prob_list, layout=torch.strided
                )

        # we obtain the current log-probs (already computed) from the current tensordict
        if self.pad_output:
            curr_log_prob_padded = tensordict.get(self.log_prob_full_key)
        else:
            curr_log_prob_unpadded = tensordict.get(
                self.log_prob_full_key, as_list=True
            )
        if self.assistant_only:
            # we want to keep the batch dimension
            curr_log_prob_list = []
            if self.pad_output:
                for i in range(curr_log_prob_padded.size(0)):
                    curr_log_prob_list.append(
                        curr_log_prob_padded[i].masked_fill(~mask[i], 0)
                    )
            else:
                for i in range(len(curr_log_prob_unpadded)):
                    curr_log_prob_list.append(
                        curr_log_prob_unpadded[i].masked_fill(~mask[i], 0)
                    )
        if self.pad_output:
            curr_log_prob = pad_sequence(
                curr_log_prob_list,
                batch_first=True,
                padding_value=0,
                padding_side=self.padding_side,
            )
        else:
            curr_log_prob = torch.nested.nested_tensor(
                curr_log_prob_list, layout=torch.strided
            )

        ref_log_prob = ref_log_prob.to(curr_log_prob.device)

        # We want the log-probs to have a similar dim to the reward
        curr_log_prob = curr_log_prob.unsqueeze(-1)
        ref_log_prob = ref_log_prob.unsqueeze(-1)

        for i in range(ref_log_prob.size(0)):
            if ref_log_prob[i].shape != curr_log_prob[i].shape:
                # Don't check shapes if nested
                raise ValueError(
                    f"the log-probability tensor shapes must match, got cur_log_prob.shape={curr_log_prob[i].shape} "
                    f"and log_prob.shape={ref_log_prob[i].shape}. "
                    f"One possible reason is that the padding token is identical to the eos token, which means that "
                    f"the eos_token log_prob is truncated from the reference model output."
                )
        kl = curr_log_prob - ref_log_prob
        if self.add_to_reward:
            reward_key = self.in_keys[0]
            reward = next_tensordict.get(reward_key)
            # we use the unbiased consistent estimator of the KL: log_p(x) - log_q(x) when x ~ p(x)
            if not reward.is_nested and ref_log_prob.is_nested:
                reward = torch.nested.nested_tensor(
                    [rew.expand(lp.shape) for rew, lp in zip(reward, ref_log_prob)],
                    layout=torch.strided,
                )
            if reward is not None and reward.ndim != curr_log_prob.ndim:
                raise ValueError(
                    "The number of dimensions of reward must be the same as the number of dimensions of the KL "
                    f"term. Got ndim={reward.ndim} and {curr_log_prob.ndim} respectively."
                )
            if reward is None:
                reward = 0
            reward = reward - self.coef * kl
            next_tensordict.set(self.out_keys[0], reward)
        next_tensordict.set(self.out_keys[1], kl)
        next_tensordict.set(self.out_keys[2], ref_log_prob)
        return next_tensordict

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        next_td = tensordict.pop("next")
        next_td = self._step(tensordict, next_td)
        return tensordict.set("next", next_td)

    def transform_output_spec(self, output_spec: Composite) -> Composite:
        in_key = unravel_key(self.in_keys[0])
        out_key = unravel_key(self.out_keys[0])

        if "full_observation_spec" in output_spec.keys():
            observation_spec = output_spec["full_observation_spec"]
        else:
            observation_spec = Composite(
                shape=output_spec.shape, device=output_spec.device
            )
            output_spec["full_observation_spec"] = observation_spec

        if in_key == "reward" and out_key == "reward":
            parent = self.parent
            reward_keys = parent.reward_keys
            if len(reward_keys) == 1:
                reward_key = reward_keys[0]
                shape = output_spec["full_reward_spec"].shape
            elif "reward" in reward_keys:
                reward_key = "reward"
                shape = output_spec["full_reward_spec"].shape
            else:
                shape = output_spec.shape
                reward_key = "reward"
            # For LLMs, the shape of the reward is (batch, -1, 1)
            shape = (*shape, -1, 1)
            reward_spec = Unbounded(
                device=output_spec.device,
                shape=shape,
            )
            output_spec["full_reward_spec"] = Composite(
                {reward_key: reward_spec},
                shape=output_spec["full_reward_spec"].shape,
            )
        elif in_key == "reward":
            # TODO: we should at least allow to make this a component of the reward specs, to avoid a call during reset
            parent = self.parent
            reward_spec = output_spec["full_reward_spec"][parent.reward_key]
            shape = output_spec["full_reward_spec"].shape
            # For LLMs, the shape of the reward is (batch, -1, 1)
            shape = (*shape, -1, 1)
            reward_spec = reward_spec.clone()
            reward_spec.shape = torch.Size(shape)
            # then we need to populate the output keys
            observation_spec[out_key] = reward_spec
        else:
            observation_spec = output_spec["full_observation_spec"]
            reward_spec = observation_spec[in_key]
            shape = observation_spec.shape
            shape = (*shape, -1, 1)
            reward_spec = reward_spec.clone()
            reward_spec.shape = torch.Size(shape)
            # then we need to populate the output keys
            observation_spec[out_key] = reward_spec

        observation_spec[self.out_keys[1]] = reward_spec.clone()

        return output_spec
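
# --- Editorial usage sketch (not part of the torchrl module) -------------------
# KLRewardTransform above is legacy: it expects a ref_model with input_mode="tokens"
# and folds the KL penalty directly into the reward. The sketch below shows the
# recommended replacement construction described in the docstring, assuming
# `gen_model` and `ref_model` are LLMWrapperBase instances with generate=False,
# return_log_probs=True, input_mode="history" and distinct log_probs_key values.
# The function name and the optional `env` argument are placeholders, not module API.
def _legacy_to_retrieve_kl_sketch(gen_model, ref_model, env=None):
    # Legacy construction (kept only for comparison; needs a tokens-mode ref_model):
    #     transform = KLRewardTransform(ref_model, coef=0.1)
    # Recommended construction:
    transform = RetrieveKL(gen_model, ref_model, assistant_only=True, coeff=0.1)
    if env is not None:
        # Attach to an existing transformed env, if one is available.
        env.append_transform(transform)
    return transform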
class RetrieveLogProb(Transform):
    """A transform to retrieve log-probabilities from a model for KL divergence computation.

    This transform computes log-probabilities from a reference model, which can then be used to compute
    KL divergence with another model's log-probabilities. It's designed to work with the
    :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL` and :class:`~torchrl.envs.llm.transforms.kl.KLComputation`
    transforms.

    Args:
        model (LLMWrapperBase): the model to use to compute the log-probs.

    Keyword Args:
        log_probs_full_key (NestedKey): the key where the log-probs are stored.
            If not provided, the key will be retrieved from the model's `log_probs_key` attribute
            (i.e., `(model.log_probs_key, "full")`).
        assistant_only (bool): whether to zero out the log-probs of the non-assistant tokens (i.e., steps of history
            where the role is not `"assistant"`). Defaults to `True`.

            .. note:: When `assistant_only=True`, the model must have `input_mode='history'` to properly identify
                assistant tokens. For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
                This ensures users are conscious of the limitation that assistant token identification requires
                structured conversation history.

        tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat template
            to the history when `assistant_only` is `True`. To control the tokenization in the ref_model, pass the
            tokenizer kwargs to the ref_model constructor.
            Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_dict": True,
            "padding": False, "add_generation_prompt": False}`.
        tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the
            assistant mask. If not provided, the tokenizer will be inferred from the `ref_model`.
        detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
        device (torch.device): the device to use for tensor creation. Defaults to `None`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.

    Examples:
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm import TransformersWrapper
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
        >>> from tensordict import TensorDict, set_list_to_stack
        >>> import torch
        >>>
        >>> # Set up list to stack for History
        >>> set_list_to_stack(True).set()
        >>>
        >>> # Create chat data
        >>> chats = [
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "Hello, how are you?"},
        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
        ...     ],
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "What's the weather like?"},
        ...         {"role": "assistant", "content": "I can't check the weather for you."},
        ...     ],
        ... ]
        >>> history = History.from_chats(chats)
        >>> print(f"Created history with shape: {history.shape}")
        Created history with shape: torch.Size([2, 3])
        >>>
        >>> # Setup tokenizer and model
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> model = OPTForCausalLM(OPTConfig()).eval()
        >>>
        >>> # Create reference model
        >>> ref_model = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,
        ...     return_log_probs=True,
        ...     pad_output=True,
        ... )
        >>>
        >>> # Create the RetrieveLogProb transform
        >>> transform = RetrieveLogProb(
        ...     ref_model,
        ...     assistant_only=True,
        ...     tokenizer=tokenizer,
        ... )
        >>>
        >>> # Prepare data using ChatHistory
        >>> chat_history = ChatHistory(full=history)
        >>> data = TensorDict(history=chat_history, batch_size=(2,))
        >>>
        >>> # Apply the transform to get reference log probabilities
        >>> result = transform(data)
        >>> log_probs_key = (ref_model.log_probs_key, "full")
        >>> ref_log_probs = result.get(log_probs_key)
        >>> print(f"Log-probs shape: {ref_log_probs.shape}")
        Log-probs shape: torch.Size([2, 26])

    .. note::
        By default, the log-probabilities are stored as a list of tensors (one per sample, with variable length).
        Use `as_padded_tensor=True` in `.get()` to obtain a batchable tensor (with padding).
        The reference log probabilities are computed only for assistant tokens when `assistant_only=True`.

        **Input Mode Compatibility:**

        - When `assistant_only=True` (default), the model must have `input_mode='history'` to properly identify
          assistant tokens.
        - When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
        - This design ensures users are conscious of the limitation that assistant token identification requires
          structured conversation history.

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two
            `RetrieveLogProb` instances with `KLComputation`.
        :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: A transform that computes KL divergence between two
            log-prob tensors.
        :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation
            (use `RetrieveKL` instead).
    """

    def __init__(
        self,
        model: LLMWrapperBase,
        *,
        log_probs_full_key: NestedKey | None = None,
        assistant_only: bool = True,
        tokenizer_kwargs: dict | None = None,
        detach: bool = True,
        device: torch.device | None = None,
        tokenizer: transformers.AutoTokenizer | None = None,
        padding_side: str = "left",
    ):
        # Set up keys
        if log_probs_full_key is None:
            log_probs_full_key = (model.log_probs_key, "full")
        elif (
            not isinstance(log_probs_full_key, tuple)
            or log_probs_full_key[-1] != "full"
        ):
            warnings.warn(
                f"The log_probs_full_key {log_probs_full_key} is not a tuple or does not end with 'full'. "
                "This may cause issues with the KL computation. "
                "Please use a tuple with the log_probs_key and 'full' as the last element."
            )
        self.log_probs_full_key = log_probs_full_key

        # Set up input/output keys
        in_keys = list(model.in_keys)
        out_keys = [self.log_probs_full_key]
        super().__init__(in_keys=in_keys, out_keys=out_keys)

        # Store model and configuration
        self.model = model
        self.assistant_only = assistant_only
        self.detach = detach
        self.device = device
        self.tokenizer = tokenizer
        self.padding_side = padding_side

        # Set up tokenizer kwargs
        if tokenizer_kwargs is None:
            tokenizer_kwargs = {}
        tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True)
        tokenizer_kwargs.setdefault("tokenize", True)
        tokenizer_kwargs.setdefault("return_dict", True)
        tokenizer_kwargs.setdefault("padding", False)
        tokenizer_kwargs.setdefault("add_generation_prompt", False)
        self.tokenizer_kwargs = tokenizer_kwargs

        # Validate model configuration (after setting assistant_only)
        self._validate_model_config(model)

    def _validate_model_config(self, model: LLMWrapperBase):
        """Validate model configuration."""
        if not getattr(model, "return_log_probs", True):
            raise ValueError(
                "The model must have `return_log_probs=True` to use the `RetrieveLogProb` transform."
            )
        if getattr(model, "generate", True):
            raise ValueError(
                "The model must have `generate=False` to use the `RetrieveLogProb` transform."
            )

        # Check input mode compatibility with assistant_only
        input_mode = getattr(model, "input_mode", "history")
        if self.assistant_only and input_mode != "history":
            raise ValueError(
                f"The model must have `input_mode='history'` when `assistant_only=True`. "
                f"Current input_mode is '{input_mode}'. "
                f"To use input_mode '{input_mode}', set `assistant_only=False`."
            )

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        next_td = tensordict.get("next")
        next_is_none = False
        if next_td is None:
            next_is_none = True
            next_td = tensordict
        output = self._step(tensordict, next_td)
        if next_is_none:
            return output
        return tensordict.set("next", output)

    def _mask_assistant_tokens(
        self, td: TensorDictBase, lp_key: NestedKey
    ) -> torch.Tensor:
        """Mask log-probs to only include assistant tokens.

        Args:
            td: TensorDict containing the data
            lp_key: Key for log-probs in the TensorDict

        Returns:
            Masked log-probs tensor
        """
        with torch.device(self.device) if self.device is not None else nullcontext():
            # Get assistant mask
            assistant_masks = td.get(("masks", "all_assistant_mask"), as_list=True)
            log_probs = td.get(lp_key, as_list=True)
            log_probs = [
                torch.masked_fill(lp, ~mask, 0.0)
                for lp, mask in _zip_strict(log_probs, assistant_masks)
            ]
            if self.model.pad_output:
                log_probs = pad_sequence(
                    log_probs,
                    batch_first=True,
                    padding_value=0.0,
                    padding_side=self.padding_side,
                )
            else:
                log_probs = torch.nested.as_nested_tensor(
                    log_probs, layout=self.model.layout
                )
        return log_probs

    @set_list_to_stack(True)
    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # Compute log-probs using the model
        # Use tensordict since we want to process the "full" entry
        ref_td = self.model(tensordict.copy())
        tmp_log_probs_key = (self.model.log_probs_key, "full")

        # Apply assistant masking if requested
        if self.assistant_only:
            log_probs = self._mask_assistant_tokens(ref_td, tmp_log_probs_key)
            ref_td.set(tmp_log_probs_key, log_probs)

        # Rename and store the log-probs
        if tmp_log_probs_key != self.log_probs_full_key:
            ref_td.rename_key_(tmp_log_probs_key, self.log_probs_full_key)
        next_tensordict.update(ref_td, keys_to_update=(self.log_probs_full_key,))
        return next_tensordict

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        # Add kl to observation spec
        observation_spec["kl_penalty"] = Unbounded(
            device=observation_spec.device,
            shape=observation_spec.shape,
        )
        return observation_spec
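
# --- Editorial usage sketch (not part of the torchrl module) -------------------
# RetrieveLogProb writes the reference log-probs under (model.log_probs_key, "full").
# The sketch below, assuming `ref_model` is a TransformersWrapper built as in the
# docstring example above (generate=False, return_log_probs=True, input_mode="history",
# pad_output=False so the outputs are stored as ragged lists), shows the two ways the
# docstring suggests reading the result. The function name is a placeholder.
def _read_ref_log_probs_sketch(ref_model, data):
    transform = RetrieveLogProb(ref_model, assistant_only=True)
    out = transform(data)
    key = (ref_model.log_probs_key, "full")
    # Ragged view: one 1D tensor of log-probs per sample (variable lengths).
    per_sample = out.get(key, as_list=True)
    # Dense view: shape (batch, max_seq_len); padded and non-assistant positions are 0.
    padded = out.get(key, as_padded_tensor=True)
    return per_sample, padded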
class RetrieveKL(Compose):
    """A transform to retrieve the KL divergence between two models' log-probabilities.

    This transform combines two :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` instances with a
    :class:`~torchrl.envs.llm.transforms.kl.KLComputation` to compute KL divergence between a generation model
    and a reference model.

    .. note:: Both gen_model and ref_model must use the same pad_output value (True or False),
        otherwise KL computation will fail.

    Args:
        gen_model (LLMWrapperBase): the generation model, wrapped in such a way that it does not generate but
            computes the log-probs.
            In cases where the transform is used within a :class:`~torchrl.collectors.llm.LLMCollector` run on a
            remote worker, the policy may not be available ahead of time. In this case, the `gen_model` can be set
            to `"from_collector"` (default) to retrieve the policy from the collector. See
            :meth:`~torchrl.modules.llm.policies.LLMWrapperBase.get_new_version` for more details about generating
            a new version of the policy to gather the log-probs.
        ref_model (LLMWrapperBase): the reference model, wrapped in such a way that it does not generate but
            computes the log-probs.

    Keyword Args:
        assistant_only (bool): whether to only retrieve the log-probs of the assistant tokens (i.e., steps of
            history where the role is `"assistant"`). Defaults to `True`.

            .. note:: When `assistant_only=True`, both models must have `input_mode='history'` to properly identify
                assistant tokens. For other input modes (`"text"` or `"tokens"`), set `assistant_only=False`.
                This ensures users are conscious of the limitation that assistant token identification requires
                structured conversation history.

        gen_log_probs_full_key (str): the key where the log-probs of the generation model are stored.
            Defaults to `("log_probs", "full")`.
        ref_log_probs_full_key (str): the key where the log-probs of the reference model are stored.
            Defaults to `("ref_log_probs", "full")`.
        history_key (str): the key where the history is stored. Defaults to `"history"`.
        tokenizer_kwargs (dict): the keyword arguments to pass to the tokenizer to be used to apply the chat
            template to the history when `assistant_only` is `True`. To control the tokenization in the actor,
            pass the tokenizer kwargs to the actor constructor.
            Defaults to `{"return_assistant_tokens_mask": True, "tokenize": True, "return_tensors": "pt",
            "padding": True, "add_generation_prompt": False}`.
        detach (bool): whether to exclude the log-probs from the gradient computation. Defaults to `True`.
        device (torch.device): the device to use for tensor creation. Defaults to `None`.
        tokenizer (transformers.AutoTokenizer): the tokenizer to be used to tokenize the input and compute the
            assistant mask. If not provided, the tokenizer will be inferred from the `actor`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.
        kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
        add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
        coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
        **kwargs: additional arguments to pass to the `RetrieveLogProb` transform.

    Examples:
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm import TransformersWrapper
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>> from transformers import AutoTokenizer, OPTConfig, OPTForCausalLM
        >>> from tensordict import TensorDict, set_list_to_stack
        >>> import torch
        >>>
        >>> # Set up list to stack for History
        >>> set_list_to_stack(True).set()
        >>>
        >>> # Create chat data
        >>> chats = [
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "Hello, how are you?"},
        ...         {"role": "assistant", "content": "I'm doing well, thank you!"},
        ...     ],
        ...     [
        ...         {"role": "system", "content": "You are a helpful assistant."},
        ...         {"role": "user", "content": "What's the weather like?"},
        ...         {"role": "assistant", "content": "I can't check the weather for you."},
        ...     ],
        ... ]
        >>> history = History.from_chats(chats)
        >>> print(f"Created history with shape: {history.shape}")
        Created history with shape: torch.Size([2, 3])
        >>>
        >>> # Setup tokenizer and model
        >>> tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
        >>> tokenizer.pad_token = tokenizer.eos_token
        >>> model = OPTForCausalLM(OPTConfig()).eval()
        >>>
        >>> # Create generation and reference models
        >>> gen_model = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,
        ...     return_log_probs=True,
        ...     pad_output=True,
        ...     log_probs_key="gen_log_probs",
        ... )
        >>> ref_model = TransformersWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=False,
        ...     return_log_probs=True,
        ...     pad_output=True,
        ...     log_probs_key="ref_log_probs",
        ... )
        >>>
        >>> # Create RetrieveKL transform
        >>> transform = RetrieveKL(
        ...     gen_model=gen_model,
        ...     ref_model=ref_model,
        ...     assistant_only=True,
        ...     tokenizer=tokenizer,
        ... )
        >>>
        >>> # Prepare data with next tensordict using ChatHistory
        >>> chat_history = ChatHistory(full=history)
        >>> next_td = TensorDict(history=chat_history, batch_size=(2,))
        >>> data = TensorDict(history=chat_history, next=next_td, batch_size=(2,))
        >>>
        >>> # Apply transform
        >>> result = transform(data)
        >>> kl = result["next"].get("kl_penalty")
        >>> print(f"KL shape: {kl.shape}")
        KL shape: torch.Size([2, 26])

    Note:
        **Input Mode Compatibility:**

        - When `assistant_only=True`, both models must have `input_mode='history'` to properly identify
          assistant tokens.
        - When `assistant_only=False`, the transform works with any input mode (`"history"`, `"text"`, or `"tokens"`).
        - This design ensures users are conscious of the limitation that assistant token identification requires
          structured conversation history.

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities
            from a single model.
        :class:`~torchrl.envs.llm.transforms.kl.KLComputation`: The transform that computes KL divergence between two
            log-prob tensors.
        :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation
            (use `RetrieveKL` instead).
    """

    def __init__(
        self,
        gen_model: LLMWrapperBase | Literal["from_collector"] = "from_collector",
        ref_model: LLMWrapperBase | None = None,
        *,
        assistant_only: bool | None = True,
        history_key: str = "history",
        tokenizer_kwargs: dict[str, Any] | None = None,
        detach: bool = True,
        device: torch.device | None = None,
        tokenizer: transformers.AutoTokenizer | None = None,
        padding_side: str = "left",
        gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
        ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
        kl_key: NestedKey = "kl_penalty",
        add_to_reward: bool = True,
        coeff: float = 1.0,
        **kwargs,
    ):
        if isinstance(gen_model, str) and gen_model == "from_collector":
            # Lazy init
            self._initialized = False
            self._init_params = {
                "ref_model": ref_model,
                "assistant_only": assistant_only,
                "history_key": history_key,
                "tokenizer_kwargs": tokenizer_kwargs,
                "detach": detach,
                "device": device,
                "tokenizer": tokenizer,
                "gen_log_probs_full_key": gen_log_probs_full_key,
                "ref_log_probs_full_key": ref_log_probs_full_key,
                "kl_key": kl_key,
                "add_to_reward": add_to_reward,
                "coeff": coeff,
                "padding_side": padding_side,
                **kwargs,
            }
            super().__init__()
            return

        self._initialized = True

        # Check pad_output consistency if both models are provided
        if hasattr(gen_model, "pad_output") and hasattr(ref_model, "pad_output"):
            if gen_model.pad_output != ref_model.pad_output:
                raise ValueError(
                    f"pad_output mismatch: gen_model.pad_output={gen_model.pad_output}, "
                    f"ref_model.pad_output={ref_model.pad_output}. "
                    "Both models must use the same padding strategy for KL computation."
                )

        if not getattr(gen_model, "return_log_probs", True):
            raise ValueError(
                "The generation model must have `return_log_probs=True` to use the `RetrieveKL` transform."
            )
        elif getattr(gen_model, "generate", False):
            raise ValueError(
                "The generation model must have `generate=False` to use the `RetrieveKL` transform."
            )
        if not getattr(ref_model, "return_log_probs", True):
            raise ValueError(
                "The reference model must have `return_log_probs=True` to use the `RetrieveKL` transform."
            )
        elif getattr(ref_model, "generate", False):
            raise ValueError(
                "The reference model must have `generate=False` to use the `RetrieveKL` transform."
            )

        if getattr(gen_model, "log_probs_key", "gen_log_probs") == getattr(
            ref_model, "log_probs_key", "log_probs"
        ):
            raise ValueError(
                "The generation and reference models must have different `log_prob_key` values to use the `RetrieveKL` transform."
            )

        t1 = RetrieveLogProb(
            gen_model,
            log_probs_full_key=gen_log_probs_full_key,
            assistant_only=assistant_only,
            tokenizer_kwargs=tokenizer_kwargs,
            detach=detach,
            device=device,
            tokenizer=tokenizer,
            padding_side=padding_side,
            **kwargs,
        )
        t2 = RetrieveLogProb(
            ref_model,
            log_probs_full_key=ref_log_probs_full_key,
            assistant_only=assistant_only,
            tokenizer_kwargs=tokenizer_kwargs,
            detach=detach,
            device=device,
            tokenizer=tokenizer,
            padding_side=padding_side,
            **kwargs,
        )
        t3 = KLComputation(
            gen_log_probs_full_key=gen_log_probs_full_key,
            ref_log_probs_full_key=ref_log_probs_full_key,
            kl_key=kl_key,
            add_to_reward=add_to_reward,
            coeff=coeff,
        )
        super().__init__(t1, t2, t3)

    def _init_deferred(self):
        torchrl_logger.info("Initializing RetrieveKL transform")
        container = self.container
        if container is None:
            # also logging, since this will be sometimes hidden within the AttributeError
            torchrl_logger.warning(
                "The container is not set. Please set the container before calling this method."
            )
            raise ValueError(
                "The container is not set. Please set the container before calling this method."
            )
        container.empty_cache()
        self.empty_cache()
        collector = self.collector
        if collector is None:
            # also logging, since this will be sometimes hidden within the AttributeError
            torchrl_logger.warning(
                "The collector is not set. Please set the collector before calling this method."
            )
            raise ValueError(
                "The collector is not set. Please set the collector before calling this method."
            )
        ref_model = self._init_params["ref_model"]
        pad_output = getattr(ref_model, "pad_output", None)
        gen_log_probs_full_key = self._init_params["gen_log_probs_full_key"]
        if (
            not isinstance(gen_log_probs_full_key, tuple)
            or gen_log_probs_full_key[-1] != "full"
        ):
            raise ValueError(
                f"The gen_log_probs_full_key {gen_log_probs_full_key} is not a tuple or does not end with 'full'. "
                "This may cause issues with the KL computation. "
                "Please use a tuple with the log_probs_key and 'full' as the last element."
            )
        log_probs_key = gen_log_probs_full_key[:-1]
        gen_model = collector.policy.get_new_version(
            generate=False,
            return_log_probs=True,
            log_probs_key=log_probs_key,
            input_mode=ref_model.input_mode,
            input_key=(ref_model.input_mode, "full"),
            pad_output=pad_output,  # Pass pad_output from ref_model
        )
        # Create the transforms manually instead of calling __init__
        t1 = RetrieveLogProb(
            gen_model,
            log_probs_full_key=gen_log_probs_full_key,
            assistant_only=self._init_params["assistant_only"],
            tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
            detach=self._init_params["detach"],
            device=self._init_params["device"],
            tokenizer=self._init_params["tokenizer"],
            padding_side=self._init_params["padding_side"],
        )
        ref_log_probs_full_key = self._init_params["ref_log_probs_full_key"]
        if (
            not isinstance(ref_log_probs_full_key, tuple)
            or ref_log_probs_full_key[-1] != "full"
        ):
            raise ValueError(
                f"The ref_log_probs_full_key {ref_log_probs_full_key} is not a tuple or does not end with 'full'. "
                "This may cause issues with the KL computation. "
                "Please use a tuple with the log_probs_key and 'full' as the last element."
            )
        t2 = RetrieveLogProb(
            ref_model,
            log_probs_full_key=ref_log_probs_full_key,
            assistant_only=self._init_params["assistant_only"],
            tokenizer_kwargs=self._init_params["tokenizer_kwargs"],
            detach=self._init_params["detach"],
            device=self._init_params["device"],
            tokenizer=self._init_params["tokenizer"],
            padding_side=self._init_params["padding_side"],
        )
        t3 = KLComputation(
            gen_log_probs_full_key=gen_log_probs_full_key,
            ref_log_probs_full_key=ref_log_probs_full_key,
            kl_key=self._init_params["kl_key"],
            add_to_reward=self._init_params["add_to_reward"],
            coeff=self._init_params["coeff"],
        )
        # Replace the transforms in the Compose
        self.transforms.extend([t1, t2, t3])
        del self._init_params
        self._initialized = True
        torchrl_logger.info("Successfully initialized")

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._step(tensordict, next_tensordict)

    def _reset(
        self, tensordict: TensorDictBase, tensordict_reset: TensorDictBase
    ) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._reset(tensordict, tensordict_reset)

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super().forward(tensordict)

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_observation_spec(observation_spec)

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_reward_spec(reward_spec)

    def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
        if not self._initialized:
            self._init_deferred()
        return super()._inv_call(tensordict)

    def transform_action_spec(self, action_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_action_spec(action_spec)

    def transform_input_spec(self, input_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_input_spec(input_spec)

    def transform_output_spec(self, output_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_output_spec(output_spec)

    def transform_state_spec(self, state_spec: Composite) -> Composite:
        if not self._initialized:
            self._init_deferred()
        return super().transform_state_spec(state_spec)
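
# --- Editorial sketch (not part of the torchrl module) -------------------------
# RetrieveKL is a Compose of two RetrieveLogProb transforms and one KLComputation,
# exactly as built in its __init__ above. The sketch below spells that composition
# out manually, assuming `gen_model` and `ref_model` are LLMWrapperBase instances
# with generate=False, return_log_probs=True, input_mode="history" (required when
# assistant_only=True) and distinct log_probs_key values. The function name is a
# placeholder, not module API.
def _manual_retrieve_kl_sketch(gen_model, ref_model):
    gen_log_probs_key = ("log_probs", "full")
    ref_log_probs_key = ("ref_log_probs", "full")
    return Compose(
        # Log-probs of the generation (current) policy on the full sequence.
        RetrieveLogProb(gen_model, log_probs_full_key=gen_log_probs_key, assistant_only=True),
        # Log-probs of the frozen reference policy on the same sequence.
        RetrieveLogProb(ref_model, log_probs_full_key=ref_log_probs_key, assistant_only=True),
        # kl = gen_log_probs - ref_log_probs, subtracted from the reward with weight coeff.
        KLComputation(
            gen_log_probs_full_key=gen_log_probs_key,
            ref_log_probs_full_key=ref_log_probs_key,
            kl_key="kl_penalty",
            add_to_reward=True,
            coeff=1.0,
        ),
    )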
class KLComputation(Transform):
    """A transform to compute KL divergence between two log-prob tensors and optionally add it to the reward.

    This transform computes KL divergence between generation and reference log-probabilities and can optionally
    subtract it from the reward (for KL penalty). It's designed to work with the
    :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb` and :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`
    transforms.

    .. note:: Both input log-prob tensors must use the same padding strategy (pad_output) for correct KL computation.

    Args:
        gen_log_probs_full_key (NestedKey): the key where the generation model log-probs are stored.
            Defaults to `("log_probs", "full")`.
        ref_log_probs_full_key (NestedKey): the key where the reference model log-probs are stored.
            Defaults to `("ref_log_probs", "full")`.

    Keyword Args:
        kl_key (NestedKey): the key where the KL divergence is stored. Defaults to `"kl_penalty"`.
        add_to_reward (bool): whether to add the KL divergence to the reward. Defaults to `True`.
        coeff (float): the coefficient for the KL term when adding to reward. Defaults to `1.0`.
        padding_side (str): the side of the padding when using pad_sequence. Defaults to `"left"`.

    Examples:
        >>> from tensordict import TensorDict
        >>> import torch
        >>>
        >>> # Create sample log-probs
        >>> gen_log_probs = torch.randn(2, 10)  # 2 samples, 10 tokens each
        >>> ref_log_probs = torch.randn(2, 10)
        >>>
        >>> # Create data with next tensordict
        >>> next_td = TensorDict(
        ...     {
        ...         ("gen_log_probs", "full"): gen_log_probs,
        ...         ("ref_log_probs", "full"): ref_log_probs,
        ...         "reward": torch.randn(2, 10, 1),
        ...     },
        ...     batch_size=(2,),
        ... )
        >>> data = TensorDict(next=next_td, batch_size=(2,))
        >>>
        >>> # Create KLComputation transform
        >>> kl_transform = KLComputation(
        ...     gen_log_probs_full_key=("gen_log_probs", "full"),
        ...     ref_log_probs_full_key=("ref_log_probs", "full"),
        ...     kl_key="kl_penalty",
        ...     add_to_reward=True,
        ...     coeff=1.0,
        ... )
        >>>
        >>> # Apply transform
        >>> result = kl_transform(data)
        >>> kl = result["next"].get("kl_penalty")
        >>> print(f"KL shape: {kl.shape}")
        KL shape: torch.Size([2, 10])

    .. seealso::
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveLogProb`: The base transform for retrieving log-probabilities
            from a single model.
        :class:`~torchrl.envs.llm.transforms.kl.RetrieveKL`: A higher-level transform that combines two
            `RetrieveLogProb` instances with `KLComputation`.
        :class:`~torchrl.envs.llm.transforms.kl.KLRewardTransform`: A legacy transform for KL reward computation
            (use `RetrieveKL` instead).
    """

    def __init__(
        self,
        gen_log_probs_full_key: NestedKey = ("log_probs", "full"),
        ref_log_probs_full_key: NestedKey = ("ref_log_probs", "full"),
        *,
        kl_key: NestedKey = "kl_penalty",
        add_to_reward: bool = True,
        coeff: float = 1.0,
        padding_side: str = "left",
    ):
        in_keys = [gen_log_probs_full_key, ref_log_probs_full_key]
        if add_to_reward:
            in_keys.append("reward")
        out_keys = [kl_key]
        if add_to_reward:
            out_keys.append("reward")
        super().__init__(in_keys=in_keys, out_keys=out_keys)
        self.gen_log_probs_full_key = gen_log_probs_full_key
        self.ref_log_probs_full_key = ref_log_probs_full_key
        self.kl_key = kl_key
        self.add_to_reward = add_to_reward
        self.coeff = coeff
        self.padding_side = padding_side

    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        next_td = tensordict.get("next")
        has_next_td = True
        if next_td is None:
            next_td = tensordict
            has_next_td = False
        next_td = self._step(tensordict, next_td)
        if has_next_td:
            return tensordict.set("next", next_td)
        return next_td

    def _step(
        self, tensordict: TensorDictBase, next_tensordict: TensorDictBase
    ) -> TensorDictBase:
        # Get log-probs
        gen_log_probs = next_tensordict.get(self.gen_log_probs_full_key, as_list=True)
        ref_log_probs = next_tensordict.get(self.ref_log_probs_full_key, as_list=True)

        if gen_log_probs is None or ref_log_probs is None:
            raise ValueError(
                f"Log-probs not found. Expected keys: {self.gen_log_probs_full_key}, {self.ref_log_probs_full_key}"
            )

        # Debug: Check lengths and shapes
        if len(gen_log_probs) != len(ref_log_probs):
            raise ValueError(
                f"Batch size mismatch: gen_log_probs has {len(gen_log_probs)} samples, "
                f"ref_log_probs has {len(ref_log_probs)} samples"
            )

        # Check individual sequence lengths
        for i, (gen_lp, ref_lp) in enumerate(_zip_strict(gen_log_probs, ref_log_probs)):
            if gen_lp.shape != ref_lp.shape:
                raise ValueError(
                    f"Sample {i} has different shapes: gen_log_probs[{i}].shape={gen_lp.shape}, "
                    f"ref_log_probs[{i}].shape={ref_lp.shape}"
                )

        # Compute KL divergence: KL(p||q) = E_p[log p - log q]
        # Here gen_log_probs = log p, ref_log_probs = log q
        kl = [
            gen_lp - ref_lp
            for gen_lp, ref_lp in _zip_strict(gen_log_probs, ref_log_probs)
        ]

        kl = torch.nested.as_nested_tensor(kl, layout=torch.strided)
        next_tensordict.set(self.kl_key, kl)

        # Add to reward if requested
        if self.add_to_reward:
            reward = next_tensordict.get("reward", as_list=True)
            if reward is not None:
                if isinstance(reward, list):
                    if reward[0].ndim != kl[0].ndim + 1:
                        raise ValueError(
                            f"The rewards have shape {reward[0].shape} but the kl has shape {kl[0].shape}. "
                            f"The rewards should have one more dimension than the KL."
                        )
                    reward = [
                        r - self.coeff * k.unsqueeze(-1)
                        for r, k in _zip_strict(reward, kl)
                    ]
                    next_tensordict.set(
                        "reward",
                        torch.nested.as_nested_tensor(reward, layout=torch.strided),
                    )
                else:
                    if reward.ndim != kl.ndim + 1:
                        raise ValueError(
                            f"The rewards have shape {reward.shape} but the kl has shape {kl.shape}. "
                            f"The rewards should have one more dimension than the KL."
                        )
                    reward = reward - self.coeff * kl.unsqueeze(-1)
                    next_tensordict.set("reward", reward)

        return next_tensordict

    def transform_observation_spec(self, observation_spec: Composite) -> Composite:
        # Add kl to observation spec
        observation_spec[self.kl_key] = Unbounded(
            device=observation_spec.device,
            shape=observation_spec.shape,
        )
        return observation_spec

    def transform_reward_spec(self, reward_spec: Composite) -> Composite:
        # Optionally adjust reward spec if KL is added to reward
        if self.add_to_reward:
            shape = reward_spec["reward"].shape
            # For LLMs, the shape of the reward is (batch, -1, 1)
            shape = (*shape, -1, 1)
            reward_spec["reward"] = reward_spec["reward"].clone()
            reward_spec["reward"].shape = torch.Size(shape)
        return reward_spec
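
# --- Editorial sketch (not part of the torchrl module) -------------------------
# KLComputation._step above computes, per sample, kl = gen_log_p - ref_log_p (a
# Monte Carlo estimate of KL(gen || ref) on tokens drawn from the generation model)
# and, when add_to_reward=True, reward <- reward - coeff * kl.unsqueeze(-1). The
# sketch below reproduces that arithmetic on plain padded tensors; the function
# name is a placeholder, not module API.
def _kl_penalty_sketch(gen_log_probs, ref_log_probs, reward, coeff=1.0):
    """gen_log_probs, ref_log_probs: (batch, seq); reward: (batch, seq, 1)."""
    kl = gen_log_probs - ref_log_probs             # per-token KL estimate, (batch, seq)
    penalized = reward - coeff * kl.unsqueeze(-1)  # KL-penalized reward, (batch, seq, 1)
    return kl, penalized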
