
Source code for torchrl.modules.llm.policies.vllm_wrapper

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import collections
import warnings
from typing import Any, Literal

import torch
from tensordict import (
    lazy_stack,
    MetaData,
    NonTensorStack,
    set_list_to_stack,
    TensorDict,
    TensorDictBase,
)
from tensordict.tensorclass import from_dataclass, TensorClass
from tensordict.utils import _zip_strict, NestedKey
from torch import distributions as D
from torch.nn.utils.rnn import pad_sequence

from torchrl.envs.utils import _classproperty
from torchrl.modules.llm.policies.common import (
    _extract_responses_from_full_histories,
    ChatHistory,
    LLMWrapperBase,
    LogProbs,
    Masks,
    Text,
    Tokens,
)
from torchrl.modules.utils.utils import _unpad_tensors

# Type imports
try:
    import transformers
    import vllm
    from vllm.outputs import RequestOutput
    from vllm.sampling_params import SamplingParams
except ImportError:
    vllm = None
    transformers = None
    SamplingParams = Any  # type: ignore
    RequestOutput = Any  # type: ignore


class vLLMWrapper(LLMWrapperBase):
    """A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation.

    This class is a subclass of :class:`~torchrl.modules.llm.policies.LLMWrapperBase` and provides a unified API for
    handling different input modalities (history, text, tokens) with consistent output structure using
    :class:`~tensordict.TensorClass` objects.

    Args:
        model (vllm.LLM | str): The vLLM model to wrap. If a string, it will be passed to `vllm.LLM`.

    Keyword Args:
        tokenizer (transformers.tokenization_utils.PreTrainedTokenizer | str | None, optional): The tokenizer to use for
            encoding and decoding text. If `None`, the tokenizer associated with the model will be used. If a string, it
            will be passed to `transformers.AutoTokenizer.from_pretrained`. Defaults to `None`.
        input_mode (str, optional): The input modality to use. Must be one of `"history"`, `"text"`, or `"tokens"`.
            Defaults to `"history"`.
        input_key (str | None, optional): The key for the input data. If `None`, defaults to
            - `("history", "prompt")` for `"history"` when `generate=True`, `("history", "full")` for `"history"` when `generate=False`
            - `("text", "prompt")` for `"text"` when `generate=True`, `("text", "full")` for `"text"` when `generate=False`
            - `("tokens", "prompt")` for `"tokens"` when `generate=True`, `("tokens", "full")` for `"tokens"` when `generate=False`
        attention_mask_key (str, optional): The key for attention masks (used in `"tokens"` mode). Defaults to `"attention_mask"`.

            .. warning:: This argument is under development and may change in the future.

        generate (bool, optional): Whether to enable text generation. If `True`, the model will generate text based on
            the input. If `False`, only log probabilities will be computed. Defaults to `True`.
        return_log_probs (bool, optional): Whether to return log probabilities. Defaults to `True`.
        generate_kwargs (dict | None, optional): Additional arguments to pass to the model's generate method.
            Defaults to `None`.
        tokenizer_kwargs (dict | None, optional): Additional arguments to pass to the tokenizer. Defaults to `None`.
        pad_output (bool, optional): Whether to pad the output sequences to a uniform length. Defaults to `False`.
        inplace (Literal[True, False, "empty"] | None, optional): Determines how the module should handle in-place
            operations. Defaults to `True`.
        device (torch.device | None, optional): The device to use for computation. Defaults to `None`.
        layout (torch.layout | None, optional): The layout to use for the output tensors when `pad_output=False`.
            Defaults to `torch.strided`.
        chat_template_name (Literal["chatml_format", "qwen"] | None, optional): The name of the chat template to use when
            applying the chat template to the history. Defaults to `None`. For `input_mode="history"` only.
        chat_template (str | None, optional): The chat template to use when applying the chat template to the history.
            Defaults to `None`. For `input_mode="history"` only.
        num_samples (int | None, optional): The number of samples to generate. Defaults to `None` (one sample, and no
            batch-dimension for it). Can also be set via the `generate_kwargs["n"] = value` argument.
        log_probs_key (NestedKey | None, optional): The key for the log probabilities
            :class:`~torchrl.modules.llm.policies.LogProbs` object. Defaults to `"log_probs"`.
        text_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Text` object. Defaults to `"text"`.
        tokens_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Tokens` object. Defaults to `"tokens"`.
        masks_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.Masks` object. Defaults to `"masks"`.
        history_key (NestedKey | None, optional): The key for the action
            :class:`~torchrl.modules.llm.policies.ChatHistory` object. Defaults to `"history"`.

    Input Keys:
        The input key depends on both `input_mode` and `generate`:

        - If `input_mode="history"` and `generate=True`: `input_key` (defaults to `("history", "prompt")`)
        - If `input_mode="history"` and `generate=False`: `input_key` (defaults to `("history", "full")`)
        - If `input_mode="text"` and `generate=True`: `input_key` (defaults to `("text", "prompt")`)
        - If `input_mode="text"` and `generate=False`: `input_key` (defaults to `("text", "full")`)
        - If `input_mode="tokens"` and `generate=True`: `input_key` (defaults to `("tokens", "prompt")`)
        - If `input_mode="tokens"` and `generate=False`: `input_key` (defaults to `("tokens", "full")`)

    Output Keys:
        The output keys are automatically determined based on the input_mode:

        - **Tokens**: Always returned (`tokens_key`, defaults to `"tokens"`)
        - **Text**: Returned for `"text"` and `"history"` modes (`text_key`, defaults to `"text"`)
        - **History**: Returned only for `"history"` mode (`history_key`, defaults to `"history"`)
        - **Masks**: Always returned (`masks_key`, defaults to `"masks"`)
        - **Log Probs**: Returned when `return_log_probs=True` (`log_probs_key`, defaults to `"log_probs"`)

        Example output structure for `input_mode="history"`:
        ```
        TensorDict(
            text=Text(prompt=..., response=..., full=...),
            masks=Masks(all_attention_mask=..., all_assistant_mask=...),
            tokens=Tokens(prompt=..., response=..., full=...),
            log_probs=LogProbs(prompt=..., response=..., full=...),
            history=ChatHistory(prompt=..., response=..., full=...)
        )
        ```

    Example:
        >>> from vllm import LLM
        >>> from transformers import AutoTokenizer
        >>> from torchrl.data.llm import History
        >>> from torchrl.modules.llm.policies import ChatHistory
        >>>
        >>> model = LLM("gpt2")
        >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
        >>>
        >>> # History input (recommended for RL environments)
        >>> wrapper = vLLMWrapper(
        ...     model,
        ...     tokenizer=tokenizer,
        ...     input_mode="history",
        ...     generate=True,
        ...     return_log_probs=True
        ... )
        >>>
        >>> history = History.from_chats([[
        ...     {"role": "user", "content": "Hello"},
        ...     {"role": "assistant", "content": "Hi there!"}
        ... ]])
        >>> chat_history = ChatHistory(prompt=history)
        >>> result = wrapper(TensorDict(history=chat_history, batch_size=(1,)))
        >>> print(result["text"].response)  # Generated text
        >>> print(result["log_probs"].response)  # Log probabilities
        >>> print(result["history"].response)  # History with response

    Attributes:
        collector: The collector associated with the module, if it exists.

    .. seealso::
        - :class:`~torchrl.modules.llm.policies.LLMWrapperBase` (see :ref:`ref_categorical_sequential`)
        - :class:`~torchrl.modules.llm.policies.TransformersWrapper` (see :ref:`ref_transformers_wrapper`)
    """

    def __init__(
        self,
        model: vllm.LLM | str,
        *,
        tokenizer: callable | str | None = None,  # type: ignore
        input_mode: str = "history",
        input_key: NestedKey | None = None,
        attention_mask_key: str = "attention_mask",
        generate: bool = True,
        generate_kwargs: dict | None = None,
        tokenizer_kwargs: dict | None = None,
        pad_output: bool = False,
        inplace: Literal[True, False, "empty"] | None = None,
        device: torch.device | None = None,
        layout: torch.layout | None = None,
        num_samples: int | None = None,
        chat_template_name: Literal["chatml_format", "qwen"] | None = None,
        chat_template: str | None = None,
        return_log_probs: bool | None = None,
        history_key: NestedKey | None = "history",
        text_key: NestedKey | None = "text",
        tokens_key: NestedKey | None = "tokens",
        masks_key: NestedKey | None = "masks",
        log_probs_key: NestedKey | None = "log_probs",
    ):
        super().__init__()

        if vllm is None:
            raise ImportError("vllm is required for vLLMWrapper")
        if transformers is None:
            raise ImportError("transformers is required for vLLMWrapper")

        if isinstance(model, str):
            model = vllm.LLM(model)
        if isinstance(tokenizer, str):
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained(tokenizer)

        from vllm import SamplingParams

        # Validate input_mode
        if input_mode not in ["history", "text", "tokens"]:
            raise ValueError(
                f"input_mode must be one of 'history', 'text', 'tokens'. Got '{input_mode}'"
            )

        self.model = model
        self._remote_calls = not isinstance(model, vllm.LLM)
        self.input_mode = input_mode
        self.attention_mask_key = attention_mask_key
        self.generate = generate

        # Auto-determine what to return based on input mode
        self.return_history = input_mode in ("history",)
        self.return_text = input_mode in ("text", "history")
        self.return_tokens = input_mode in ("tokens", "history", "text")
        self.return_masks = True
        if return_log_probs is False and not generate:
            raise ValueError("return_log_probs must be True when generate=False.")
        return_log_probs = (
            True
            if (return_log_probs is None and generate) or (not generate)
            else bool(return_log_probs)
        )
        self.return_log_probs = return_log_probs

        self.history_key = history_key
        self.log_probs_key = log_probs_key
        self.masks_key = masks_key
        self.text_key = text_key
        self.tokens_key = tokens_key

        if not isinstance(pad_output, bool):
            raise ValueError("pad_output must be a boolean")
        self.pad_output = pad_output
        self._device = device
        if not pad_output and layout is None:
            layout = torch.strided
        self.layout = layout
        padding_value = None

        # Set input keys based on mode and generate parameter
        if input_mode == "history":
            if generate:
                self.in_keys = [
                    ("history", "prompt") if input_key is None else input_key
                ]
            else:
                self.in_keys = [("history", "full") if input_key is None else input_key]
        elif input_mode == "text":
            if generate:
                self.in_keys = [("text", "prompt") if input_key is None else input_key]
            else:
                self.in_keys = [("text", "full") if input_key is None else input_key]
        elif input_mode == "tokens":
            if generate:
                self.in_keys = [
                    ("tokens", "prompt") if input_key is None else input_key
                ]
            else:
                self.in_keys = [("tokens", "full") if input_key is None else input_key]
        else:
            raise ValueError(f"Invalid input_mode: {input_mode}")
        self.input_key = self.in_keys[0]

        # Set output keys based on auto-determined return flags
        self.out_keys = []
        if self.return_text:
            self.out_keys.append(self.text_key)
        if self.return_masks:
            self.out_keys.append(self.masks_key)
        if self.return_tokens:
            self.out_keys.append(self.tokens_key)
        if self.return_log_probs:
            self.out_keys.append(self.log_probs_key)
        if self.return_history:
            self.out_keys.append(self.history_key)

        # Tokenizer setup
        if not tokenizer_kwargs:
            tokenizer_kwargs = {}
        if not tokenizer_kwargs.setdefault("return_attention_mask", True):
            raise RuntimeError("return_attention_mask must be True")

        # If we don't pad, we use lists
        return_tensors = "pt" if self.pad_output else False
        if return_tensors:
            if (
                tokenizer_kwargs.setdefault("return_tensors", return_tensors)
                != return_tensors
            ):
                raise RuntimeError

        if tokenizer_kwargs.setdefault("padding", self.pad_output) not in (
            self.pad_output,
        ):
            raise RuntimeError
        if tokenizer_kwargs.setdefault("padding_side", "left") != "left":
            raise RuntimeError
        self.tokenizer_kwargs = tokenizer_kwargs

        # Get tokenizer if needed
        if tokenizer is None:
            try:
                tokenizer = model.get_tokenizer()
            except AttributeError:
                warnings.warn("No tokenizer provided and no tokenizer found in model.")
        self.tokenizer = tokenizer

        if self.tokenizer is not None and (
            not hasattr(self.tokenizer, "pad_token")
            or self.tokenizer.pad_token is None
        ):
            self.tokenizer.pad_token = self.tokenizer.eos_token
        if self.tokenizer is not None:
            padding_value = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0]
        self.padding_value = padding_value

        # Generate kwargs setup
        if generate_kwargs is None:
            generate_kwargs = {}
        else:
            generate_kwargs = dict(generate_kwargs)

        self.num_samples = num_samples
        if generate_kwargs.get("n", 1) > 1 or num_samples is not None:
            if inplace in (True, "empty"):
                raise ValueError(
                    "inplace must be False (or None) when generating more than one sample."
                )
            if inplace is None:
                inplace = False
            if (
                generate_kwargs.get("n", 1) > 1
                and num_samples is not None
                and generate_kwargs.get("n", 1) != num_samples
            ):
                raise ValueError("num_samples differs from generate_kwargs['n'].")
            elif num_samples is None:
                self.num_samples = generate_kwargs.get("n", 1)
            generate_kwargs["n"] = self.num_samples
        elif inplace is None:
            inplace = True

        self.inplace = inplace

        prompt_logprobs = return_log_probs

        if not generate:
            # We want only the log-probs, we generate a single token (that we then discard)
            # and retrieve the prompt log-probs
            generate_kwargs["max_tokens"] = 1
            if not return_log_probs:
                raise ValueError("return_log_probs must be True when generate=False.")

        generate_kwargs.setdefault("detokenize", not pad_output)
        generate_kwargs.setdefault("prompt_logprobs", prompt_logprobs)
        generate_kwargs.setdefault("logprobs", return_log_probs)
        generate_kwargs.setdefault("include_stop_str_in_output", True)
        generate_kwargs.setdefault("skip_special_tokens", False)

        sampling_params = SamplingParams(**generate_kwargs)
        self.sampling_params = sampling_params

        # Additional transformers-specific settings
        self.chat_template_name = chat_template_name
        self.chat_template = chat_template
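
    # A minimal configuration sketch (not part of the original source): the same class
    # covers both sampling and scoring. The checkpoint name below is illustrative.
    #
    #     >>> policy = vLLMWrapper(
    #     ...     "Qwen/Qwen2.5-0.5B-Instruct",
    #     ...     input_mode="history",
    #     ...     generate=True,   # reads ("history", "prompt") and samples a response
    #     ... )
    #     >>> scorer = vLLMWrapper(
    #     ...     "Qwen/Qwen2.5-0.5B-Instruct",
    #     ...     input_mode="history",
    #     ...     generate=False,  # reads ("history", "full") and returns prompt log-probs only
    #     ... )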
    def get_new_version(self, **kwargs):
        """Returns a new version of the module with altered parameters.

        For instance, the generate parameter can be altered to enable text generation or log-probabilities computation.
        This is especially useful when one wants to avoid re-initializing the module with a new set of parameters, when
        the same parameters could be used to gather log-probs.

        Positional arguments are not supported.

        See the class constructor for more details about the parameters.
        """
        # Build the constructor arguments by using current values for missing parameters
        constructor_kwargs = {}

        # Model is always required
        constructor_kwargs["model"] = kwargs.get("model", self.model)

        # Check for each parameter and use current value if not provided
        if "tokenizer" in kwargs:
            constructor_kwargs["tokenizer"] = kwargs["tokenizer"]
        elif hasattr(self, "tokenizer"):
            constructor_kwargs["tokenizer"] = self.tokenizer

        if "input_mode" in kwargs:
            constructor_kwargs["input_mode"] = kwargs["input_mode"]
        elif hasattr(self, "input_mode"):
            constructor_kwargs["input_mode"] = self.input_mode

        if "input_key" in kwargs:
            constructor_kwargs["input_key"] = kwargs["input_key"]
        # Since the input_key is dynamically determined, we don't want to set it here
        # elif hasattr(self, "input_key"):
        #     constructor_kwargs["input_key"] = self.input_key

        if "attention_mask_key" in kwargs:
            constructor_kwargs["attention_mask_key"] = kwargs["attention_mask_key"]
        elif hasattr(self, "attention_mask_key"):
            constructor_kwargs["attention_mask_key"] = self.attention_mask_key

        if "generate" in kwargs:
            constructor_kwargs["generate"] = kwargs["generate"]
        elif hasattr(self, "generate"):
            constructor_kwargs["generate"] = self.generate

        if "return_log_probs" in kwargs:
            constructor_kwargs["return_log_probs"] = kwargs["return_log_probs"]
        elif not constructor_kwargs.get("generate", True):
            # if we are not generating, we want to return log-probs
            constructor_kwargs["return_log_probs"] = True
        elif hasattr(self, "return_log_probs"):
            constructor_kwargs["return_log_probs"] = self.return_log_probs

        if "generate_kwargs" in kwargs:
            constructor_kwargs["generate_kwargs"] = kwargs["generate_kwargs"]
        elif hasattr(self, "generate_kwargs"):
            constructor_kwargs["generate_kwargs"] = self.generate_kwargs

        if "pad_output" in kwargs:
            constructor_kwargs["pad_output"] = kwargs["pad_output"]
        elif hasattr(self, "pad_output"):
            constructor_kwargs["pad_output"] = self.pad_output

        if "tokenizer_kwargs" in kwargs:
            constructor_kwargs["tokenizer_kwargs"] = kwargs["tokenizer_kwargs"]
        elif hasattr(self, "tokenizer_kwargs"):
            constructor_kwargs["tokenizer_kwargs"] = dict(self.tokenizer_kwargs)
            if (
                "pad_output" in kwargs
                and kwargs.get("pad_output")
                != constructor_kwargs["tokenizer_kwargs"]["padding"]
            ):
                constructor_kwargs["tokenizer_kwargs"]["padding"] = kwargs.get(
                    "pad_output"
                )

        if "inplace" in kwargs:
            constructor_kwargs["inplace"] = kwargs["inplace"]
        elif hasattr(self, "inplace"):
            constructor_kwargs["inplace"] = self.inplace

        if "device" in kwargs:
            constructor_kwargs["device"] = kwargs["device"]
        elif hasattr(self, "_device"):
            constructor_kwargs["device"] = self._device

        if "layout" in kwargs:
            constructor_kwargs["layout"] = kwargs["layout"]
        elif hasattr(self, "layout"):
            constructor_kwargs["layout"] = self.layout

        if "num_samples" in kwargs:
            constructor_kwargs["num_samples"] = kwargs["num_samples"]
        elif hasattr(self, "num_samples"):
            constructor_kwargs["num_samples"] = self.num_samples

        if "chat_template_name" in kwargs:
            constructor_kwargs["chat_template_name"] = kwargs["chat_template_name"]
        elif hasattr(self, "chat_template_name"):
            constructor_kwargs["chat_template_name"] = self.chat_template_name

        if "chat_template" in kwargs:
            constructor_kwargs["chat_template"] = kwargs["chat_template"]
        elif hasattr(self, "chat_template"):
            constructor_kwargs["chat_template"] = self.chat_template

        if "history_key" in kwargs:
            constructor_kwargs["history_key"] = kwargs["history_key"]
        elif hasattr(self, "history_key"):
            constructor_kwargs["history_key"] = self.history_key

        if "text_key" in kwargs:
            constructor_kwargs["text_key"] = kwargs["text_key"]
        elif hasattr(self, "text_key"):
            constructor_kwargs["text_key"] = self.text_key

        if "tokens_key" in kwargs:
            constructor_kwargs["tokens_key"] = kwargs["tokens_key"]
        elif hasattr(self, "tokens_key"):
            constructor_kwargs["tokens_key"] = self.tokens_key

        if "masks_key" in kwargs:
            constructor_kwargs["masks_key"] = kwargs["masks_key"]
        elif hasattr(self, "masks_key"):
            constructor_kwargs["masks_key"] = self.masks_key

        if "log_probs_key" in kwargs:
            constructor_kwargs["log_probs_key"] = kwargs["log_probs_key"]
        elif hasattr(self, "log_probs_key"):
            constructor_kwargs["log_probs_key"] = self.log_probs_key

        # Create and return new instance
        return type(self)(**constructor_kwargs)
    @set_list_to_stack(True)
    def forward(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        **kwargs,
    ) -> TensorDictBase:
        if not tensordict.ndim:
            # unsqueeze - squeeze the input
            try:
                return self(lazy_stack([tensordict])).squeeze(0)
            except Exception as e:
                raise RuntimeError(
                    f"Unsqueeze/squeeze failed. Inputs to {type(self).__name__} should ideally be 1 dimensional."
                ) from e
        elif tensordict.ndim > 1:
            return self(tensordict.reshape(-1)).view(tensordict.shape)

        _source_device = None
        if self._device:
            _source_device = tensordict.device
        if tensordict.device:
            tensordict = tensordict.copy().clear_device_()

        if kwargs:
            from vllm import SamplingParams

            sampling_params = SamplingParams(**kwargs)
        else:
            sampling_params = self.sampling_params

        if self.num_samples is not None:
            out = (
                TensorDict(
                    device=tensordict.device,
                    batch_size=(
                        tensordict.batch_size[0],
                        self.num_samples,
                        *tensordict.batch_size[1:],
                    ),
                )
                .to_lazystack(1)
                .to_lazystack(0)
            )
        else:
            out = TensorDict(
                device=tensordict.device, batch_size=tensordict.batch_size
            ).to_lazystack(0)

        if self.input_mode == "history":
            if self.generate:
                out = self._from_vllm_generate_history(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_history(tensordict, sampling_params, out)
        elif self.input_mode == "text":
            if self.generate:
                out = self._from_vllm_generate_text(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_text(tensordict, sampling_params, out)
        elif self.input_mode == "tokens":
            if self.generate:
                out = self._from_vllm_generate_tokens(tensordict, sampling_params, out)
            else:
                out = self._from_vllm_logprobs_tokens(tensordict, sampling_params, out)

        if _source_device:
            out = out.to(_source_device)

        if tensordict_out is None:
            if self.inplace is True:
                # The output is the input
                tensordict_out = tensordict
            elif self.inplace is False:
                # The output is the new structure
                tensordict_out = out
            elif self.inplace == "empty":
                # The output is empty
                tensordict_out = tensordict.empty()

        if tensordict_out is not None and tensordict_out is not out:
            result = tensordict_out.exclude(*self.out_keys, inplace=True)
            result.update(out, keys_to_update=self.out_keys)
        elif tensordict_out is out:
            result = out.select(*self.out_keys)
        elif self.inplace:
            result = out
            keys = list(set(self.out_keys + list(tensordict.keys(True, True))))
            result = tensordict.exclude(*self.out_keys, inplace=True).update(
                result, keys_to_update=keys
            )
        else:
            result = out
        return result
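
    # Sketch of a forward call in generation mode (names are illustrative; see the class
    # docstring for a complete example):
    #
    #     >>> td = TensorDict(history=ChatHistory(prompt=history), batch_size=(2,))
    #     >>> out = policy(td)
    #     >>> out["text"].response            # generated strings
    #     >>> out["tokens"].response          # generated token ids
    #     >>> out["log_probs"].response       # log-probs of the sampled tokens
    #     >>> out["masks"].all_attention_mask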
def _from_vllm_generate_history( self, tensordict_input: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase, ) -> TensorDictBase: """Generate text from history input.""" from torchrl.data.llm import History assert isinstance( tensordict_input, TensorDictBase ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Validate input if self.input_key not in tensordict_input: raise ValueError( f"Expected '{self.input_key}' key for history input mode, " f"but found keys: {list(tensordict_input.keys())}" ) history = tensordict_input.get(self.input_key) if not isinstance(history, History): raise TypeError( f"Expected History object for '{self.input_key}', got {type(history)}" ) # Apply chat template tokenizer_kwargs = {} if self.chat_template_name is not None: tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name) if self.chat_template is not None: tokenizer_kwargs.setdefault("chat_template", self.chat_template) tokenizer_kwargs.setdefault("add_generation_prompt", True) text_prompt = history.apply_chat_template( tokenizer=self.tokenizer, **tokenizer_kwargs ) tokenizer_kwargs.setdefault("return_assistant_tokens_mask", False) tokenizer_kwargs.setdefault("tokenize", True) tokenizer_kwargs.setdefault("padding", False) tokenizer_kwargs.setdefault("return_dict", True) response_struct = history.apply_chat_template( tokenizer=self.tokenizer, **tokenizer_kwargs ) tokens_prompt_padded = None tokens_prompt_unpadded = None if self.pad_output: tokens_prompt_padded = response_struct.get( "input_ids", as_padded_tensor=True, padding_value=self.padding_value, padding_side="left", ) else: tokens_prompt_unpadded = response_struct.get("input_ids", as_list=True) result = self._generate_from_tokens( tokens_prompt_padded=tokens_prompt_padded, tokens_prompt_unpadded=tokens_prompt_unpadded, sampling_params=sampling_params, out=out, ) # Generate using text path if self.pad_output: result[(self.tokens_key, "prompt")] = ( tokens_prompt_padded if not self.num_samples else tokens_prompt_padded.unsqueeze(1).repeat(1, self.num_samples, 1) ) else: tokens_prompt_nested = torch.nested.as_nested_tensor(tokens_prompt_unpadded) if not self.num_samples: result[(self.tokens_key, "prompt")] = tokens_prompt_nested else: for r in result.unbind(1): r[(self.tokens_key, "prompt")] = tokens_prompt_nested text_result = Text._from_tensordict(result.empty()) result.set(self.text_key, text_result) if not self.num_samples: text_result.prompt = text_prompt else: for r in result.unbind(1): r[self.text_key, "prompt"] = text_prompt with result.view(-1) as result_flat: if self.pad_output: tokens_full_padded = result_flat.get( (self.tokens_key, "full"), as_padded_tensor=True, padding_side="right", padding_value=self.padding_value, ) if tokens_full_padded is None: raise ValueError("tokens_full_padded is None") text_full = self.tokenizer.batch_decode( tokens_full_padded, skip_special_tokens=False ) else: tokens_full_unpadded = result_flat.get( (self.tokens_key, "full"), as_list=True ) # print("shapes of assistant masks", [t.shape for t in result_flat.get(("masks", "all_assistant_mask"), as_list=True)]) if tokens_full_unpadded is None: raise ValueError("tokens_full_unpadded is None") text_full = self.tokenizer.batch_decode( tokens_full_unpadded, skip_special_tokens=False ) text_prompt = 
result_flat[self.text_key, "prompt"] text_response = [ txt[len(prompt) :] for txt, prompt in _zip_strict(text_full, text_prompt) ] result_flat.set((self.text_key, "full"), text_full) result_flat.set((self.text_key, "response"), text_response) # Now parse the full text back to a history object, and use the extra history objects # as response history_chat = ChatHistory._from_tensordict(result.empty()) if self.num_samples is None: history_chat.prompt = history else: for h in history_chat.unbind(1): h.prompt = history with history_chat.view(-1) as history_chat_flat: prompt_histories = history_chat_flat.prompt # Extract response histories from full text h_responses = _extract_responses_from_full_histories( text_full, prompt_histories, self.chat_template_name, self.tokenizer ) history_chat_flat.response = h_responses result.set(self.history_key, history_chat) return result def _from_vllm_logprobs_history( self, tensordict_input: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase, ) -> TensorDictBase: """Compute log-probs from history input.""" assert isinstance( tensordict_input, TensorDictBase ), f"tensordict_input must be TensorDictBase, got {type(tensordict_input)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" from torchrl.data.llm import History # Validate input if self.input_key not in tensordict_input: raise ValueError( f"Expected '{self.input_key}' key for history input mode, " f"but found keys: {list(tensordict_input.keys())}" ) history = tensordict_input.get(self.input_key) if not isinstance(history, History): raise TypeError( f"Expected History object for '{self.input_key}', got {type(history)}" ) # Apply chat template tokenizer_kwargs = {} if self.chat_template_name is not None: tokenizer_kwargs.setdefault("chat_template_name", self.chat_template_name) if self.chat_template is not None: tokenizer_kwargs.setdefault("chat_template", self.chat_template) tokenizer_kwargs.setdefault("add_generation_prompt", False) text_full = history.apply_chat_template( tokenizer=self.tokenizer, **tokenizer_kwargs ) tokenizer_kwargs.setdefault("return_assistant_tokens_mask", True) tokenizer_kwargs.setdefault("tokenize", True) tokenizer_kwargs.setdefault("padding", False) tokenizer_kwargs.setdefault("return_dict", True) response_struct = history.apply_chat_template( tokenizer=self.tokenizer, **tokenizer_kwargs ) result = self._logprobs_from_tokens( response_struct=response_struct, sampling_params=sampling_params, out=out ) text_result = Text._from_tensordict(result.empty()) result.set(self.text_key, text_result) result[self.text_key, "full"] = text_full result.set(self.history_key, ChatHistory(full=history)) return result def _from_vllm_generate_text( self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase ) -> TensorDictBase: """Generate text from text input.""" # Type assertions assert isinstance( td, TensorDictBase ), f"td must be TensorDictBase, got {type(td)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Validate input if self.input_key not in td: raise ValueError( f"Expected '{self.input_key}' key for text input mode, " f"but found keys: {list(td.keys())}" ) text = td.get(self.input_key) if text is None: raise 
ValueError(f"Expected '{self.input_key}' key for text input mode") return self._generate_from_text(text, sampling_params, out) def _from_vllm_logprobs_text( self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase ) -> TensorDictBase: """Compute log-probs from text input.""" # Type assertions assert isinstance( td, TensorDictBase ), f"td must be TensorDictBase, got {type(td)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Validate input if self.input_key not in td: raise ValueError( f"Expected '{self.input_key}' key for text input mode, " f"but found keys: {list(td.keys())}" ) text = td.get(self.input_key) if text is None: raise ValueError(f"Expected '{self.input_key}' key for text input mode") return self._logprobs_from_text(text, sampling_params, out) def _from_vllm_generate_tokens( self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase ) -> TensorDictBase: """Generate text from tokens input.""" # Type assertions assert isinstance( td, TensorDictBase ), f"td must be TensorDictBase, got {type(td)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Validate input if self.input_key not in td: raise ValueError( f"Expected '{self.input_key}' key for tokens input mode, " f"but found keys: {list(td.keys())}" ) tokens_prompt_padded = None tokens_prompt_unpadded = None if self.pad_output: tokens_prompt_padded = td.get(self.input_key) else: tokens_prompt_unpadded = list(td.get(self.input_key, as_list=True)) # make sure we remove the padding tokens tokens_prompt_unpadded = [ tokens[tokens != self.padding_value] for tokens in tokens_prompt_unpadded ] return self._generate_from_tokens( tokens_prompt_unpadded=tokens_prompt_unpadded, tokens_prompt_padded=tokens_prompt_padded, sampling_params=sampling_params, out=out, ) def _from_vllm_logprobs_tokens( self, td: TensorDictBase, sampling_params: SamplingParams, out: TensorDictBase ) -> TensorDictBase: """Compute log-probs from tokens input.""" # Type assertions assert isinstance( td, TensorDictBase ), f"td must be TensorDictBase, got {type(td)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Validate input if self.input_key not in td: raise ValueError( f"Expected '{self.input_key}' key for tokens input mode, " f"but found keys: {list(td.keys())}" ) tokens_full_padded = None tokens_full_unpadded = None if self.pad_output: tokens_full_padded = td.get(self.input_key) else: tokens_full_unpadded = list(td.get(self.input_key, as_list=True)) # make sure we remove the padding tokens tokens_full_unpadded = [ tokens[tokens != self.padding_value] for tokens in tokens_full_unpadded ] return self._logprobs_from_tokens( response_struct=None, tokens_full_unpadded=tokens_full_unpadded, tokens_full_padded=tokens_full_padded, sampling_params=sampling_params, out=out, ) def _cat_text( self, text: str | list[str], response_text: str | list[str] ) -> str | list[str]: """Concatenate text and response text.""" assert isinstance( text, (str, list) ), f"text must be str or list, got {type(text)}" assert isinstance( response_text, (str, 
list) ), f"response_text must be str or list, got {type(response_text)}" if isinstance(text, list): return [self._cat_text(t, t_) for t, t_ in _zip_strict(text, response_text)] else: return text + response_text def _generate_from_text( self, text: str | list[str] | NonTensorStack, sampling_params: SamplingParams, out: TensorDictBase, ) -> TensorDictBase: """Generate text from text input.""" # Convert text to list format if isinstance(text, str): text = [text] elif not isinstance(text, list): text = text.tolist() assert isinstance( text, (str, list) ), f"text must be str or list, got {type(text)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" generate_kwargs = {"sampling_params": sampling_params} args = () # Convert text to list format if isinstance(text, str): text = [text] elif not isinstance(text, list): text = text.tolist() if not self._remote_calls: request_output = self.model.generate(text, *args, **generate_kwargs) else: import ray request_output = ray.get( self.model.generate.remote(text, *args, **generate_kwargs) ) request_output_tc = _RequestOutput_tc.from_request_output(request_output) # Extract response tokens and text outputs = ( request_output_tc.outputs.view(-1) if self.num_samples is not None else request_output_tc.outputs ) if self.pad_output: response_tokens_padded = outputs.view(-1).get( "token_ids", as_padded_tensor=self.pad_output, padding_value=self.padding_value, padding_side="right", ) response_tokens_list = outputs.view(-1).get( "token_ids", as_list=True, ) self._check_not_padded(response_tokens_list) if self.tokenizer is not None: response_text = self.tokenizer.batch_decode( response_tokens_list, skip_special_tokens=False ) else: response_text = None # Build output TensorClass objects masks_obj = Masks._from_tensordict(out.empty()) masks_obj.all_attention_mask = None masks_obj.all_assistant_mask = None masks_obj.padded = MetaData(self.pad_output) out.set(self.masks_key, masks_obj) if self.num_samples is not None: text = [txt for txt in text for _ in range(self.num_samples)] text_obj = Text._from_tensordict(out.empty()) with text_obj.view(-1) as text_obj_flat: text_obj_flat.prompt = text text_obj_flat.response = response_text text_obj_flat.full = self._cat_text(text, response_text) out.set(self.text_key, text_obj) tokens_obj = Tokens._from_tensordict(out.empty()) with tokens_obj.view(-1) as tokens_obj_flat: tokens_obj_flat.prompt = None # We don't have prompt tokens in this path if self.pad_output: tokens_obj_flat.response = response_tokens_padded self._check_padded(response_tokens_padded) else: tokens_obj_flat.response = response_tokens_list self._check_not_padded(response_tokens_list) tokens_obj_flat.full = ( None # we don't have prompt tokens in this path so no all_tokens either ) tokens_obj.padded = MetaData(self.pad_output) out.set(self.tokens_key, tokens_obj) if self.return_log_probs: log_probs_obj = LogProbs._from_tensordict(out.empty()) with log_probs_obj.view(-1) as log_probs_obj_flat: if self.pad_output: log_probs_padded = outputs.get( "logprobs", as_padded_tensor=self.pad_output, padding_value=self.padding_value, padding_side="right", ) self._check_padded(log_probs_padded) log_probs_obj_flat.response = log_probs_padded log_probs_obj_flat.full = log_probs_padded else: log_probs_list = outputs.get( "logprobs", as_list=True, ) self._check_not_padded(log_probs_list) log_probs_obj_flat.response 
= log_probs_list log_probs_obj_flat.full = log_probs_list log_probs_obj_flat.prompt = None log_probs_obj.padded = MetaData(self.pad_output) out.set(self.log_probs_key, log_probs_obj) return out def _logprobs_from_text( self, text: str | list[str] | NonTensorStack, sampling_params: SamplingParams, out: TensorDictBase, ) -> TensorDictBase: """Compute log-probs from text input.""" # Convert text to list format if isinstance(text, str): text = [text] elif not isinstance(text, list): text = text.tolist() assert isinstance( text, (str, list) ), f"text must be str or list, got {type(text)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" # Tokenize the text if self.tokenizer is None: raise ValueError( "Tokenizer is required for log-probs computation with text input" ) # Tokenize the text tokenized_output = self.tokenizer(text, **self.tokenizer_kwargs) if self.pad_output: tokens_full_padded = tokenized_output["input_ids"] attention_mask_full_padded = tokenized_output["attention_mask"] tokens_full_list = self._to_list( tokens_full_padded, attention_mask_full_padded ) else: tokens_full_unpadded = tokenized_output["input_ids"] tokens_full_list = self._to_list(tokens_full_unpadded, None) attention_mask_full_unpadded = tokenized_output["attention_mask"] attention_mask_full_unpadded = [ am.bool() if isinstance(am, torch.Tensor) else torch.tensor(am, dtype=torch.bool) for am in attention_mask_full_unpadded ] # Convert to list format for vLLM generate_kwargs = { "sampling_params": sampling_params, "prompt_token_ids": tokens_full_list, } # Generate with vLLM to get prompt_logprobs if not self._remote_calls: request_output = self.model.generate(**generate_kwargs) else: import ray request_output = ray.get(self.model.generate.remote(**generate_kwargs)) request_output_tc = _RequestOutput_tc.from_request_output(request_output) # Extract log-probs from prompt_logprobs if self.pad_output: # For padded case, use all prompt_logprobs log_probs_full_padded = request_output_tc.get( "prompt_logprobs", as_padded_tensor=True, padding_value=0, padding_side="left", ) # Mask out padding attention_mask_full_padded = tokens_full_padded != self.padding_value log_probs_full_padded = torch.where( attention_mask_full_padded, log_probs_full_padded, 0.0 ) else: # For unpadded case, extract from each sequence log_probs_full_unpadded = request_output_tc.get( "prompt_logprobs", as_list=True ) self._check_not_padded(log_probs_full_unpadded) masks_obj = Masks._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(attention_mask_full_padded) masks_obj.all_attention_mask = attention_mask_full_padded.bool() else: self._check_not_padded(attention_mask_full_unpadded) masks_obj.all_attention_mask = attention_mask_full_unpadded masks_obj.padded = MetaData(self.pad_output) out.set(self.masks_key, masks_obj) # Build output TensorClass objects text_obj = Text._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) text_obj.prompt = None text_obj.response = None text_obj.full = text out.set(self.text_key, text_obj) tokens_obj = Tokens._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(tokens_full_padded) tokens_obj.full = tokens_full_padded else: tokens_obj.full = tokens_full_unpadded tokens_obj.response = None tokens_obj.padded = 
MetaData(self.pad_output) out.set(self.tokens_key, tokens_obj) if self.return_log_probs: log_probs_obj = LogProbs._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(log_probs_full_padded) log_probs_obj.full = log_probs_full_padded else: self._check_not_padded(log_probs_full_unpadded) log_probs_obj.full = log_probs_full_unpadded log_probs_obj.response = None log_probs_obj.padded = MetaData(self.pad_output) out.set(self.log_probs_key, log_probs_obj) return out def _cat_tensors( self, tokens: list[torch.Tensor] | torch.Tensor, response_tokens: list[torch.Tensor] | torch.Tensor, ) -> list[torch.Tensor] | torch.Tensor: """Concatenate tokens and response tokens.""" if isinstance(tokens, list) or isinstance(response_tokens, list): return [ self._cat_tensors(t, t_) for t, t_ in _zip_strict(tokens, response_tokens) ] else: return torch.cat([tokens, response_tokens], dim=-1) def _generate_from_tokens( self, tokens_prompt_unpadded: list[torch.Tensor] | None, tokens_prompt_padded: torch.Tensor | None, sampling_params: SamplingParams, out: TensorDictBase, ) -> TensorDictBase: """Generate text from tokens input.""" assert isinstance( tokens_prompt_padded, (torch.Tensor, type(None)) ), f"tokens_prompt_padded must be torch.Tensor or None, got {type(tokens_prompt_padded)}" assert isinstance( tokens_prompt_unpadded, (list, type(None)) ), f"tokens_prompt_unpadded must be list or None, got {type(tokens_prompt_unpadded)}" assert isinstance( sampling_params, SamplingParams ), f"sampling_params must be SamplingParams, got {type(sampling_params)}" assert isinstance( out, TensorDictBase ), f"out must be TensorDictBase, got {type(out)}" generate_kwargs = {"sampling_params": sampling_params} args = () if tokens_prompt_unpadded is None: # TODO: To be on the safe side, we may do this even in the unpadded case since we're not sure # the user passed an unpadded tensor in the first place. 
tokens_prompt_list = self._to_list( tokens_prompt_padded, tokens_prompt_padded != self.padding_value ) else: tokens_prompt_list = self._to_list(tokens_prompt_unpadded, None) generate_kwargs.update({"prompt_token_ids": tokens_prompt_list}) if not self._remote_calls: request_output = self.model.generate(*args, **generate_kwargs) else: import ray request_output = ray.get( self.model.generate.remote(*args, **generate_kwargs) ) request_output_tc = _RequestOutput_tc.from_request_output(request_output) # Extract response tokens and text outputs = ( request_output_tc.outputs.view(-1) if self.num_samples is not None else request_output_tc.outputs ) if self.pad_output: tokens_response_padded = outputs.get( "token_ids", as_padded_tensor=self.pad_output, padding_value=self.padding_value, padding_side="right", ) self._check_padded(tokens_response_padded) tokens_response_unpadded = outputs.get( "token_ids", as_list=True, ) self._check_not_padded(tokens_response_unpadded) tokens_obj = Tokens._from_tensordict(out.empty()) if self.pad_output: self._check_padded(tokens_response_padded) self._check_padded(tokens_prompt_padded) else: self._check_not_padded(tokens_response_unpadded) self._check_not_padded(tokens_prompt_unpadded) if self.num_samples is not None: # replicate tokens for i in range(self.num_samples): tokens_obj[:, i].prompt = ( tokens_prompt_unpadded if not self.pad_output else tokens_prompt_padded ) else: tokens_obj.prompt = ( tokens_prompt_unpadded if not self.pad_output else tokens_prompt_padded ) with tokens_obj.view(-1) as tokens_obj_flat: if self.pad_output: tokens_obj_flat.response = tokens_response_padded tokens_full_padded = self._cat_tensors( tokens_obj_flat.prompt, tokens_response_padded ) tokens_obj_flat.full = tokens_full_padded else: tokens_obj_flat.response = tokens_response_unpadded tokens_full_unpadded = self._cat_tensors( tokens_obj_flat.get("prompt", as_list=True), tokens_response_unpadded, ) tokens_obj_flat.full = tokens_full_unpadded tokens_obj.padded = MetaData(self.pad_output) out.set(self.tokens_key, tokens_obj) masks_obj = Masks._from_tensordict(out.empty()) # self.return_tokens must be True if self.pad_output: # Get "real" attention masks full_attention_mask_padded = tokens_obj.get("full") != self.padding_value masks_obj.all_attention_mask = full_attention_mask_padded.bool() else: # Get "real" attention masks # We can use select to avoid batch-size problems _td = torch.ones_like( out.select(("tokens", "full")) .copy() .rename_key_(("tokens", "full"), "all_attention_mask") ).bool() del _td["tokens"] masks_obj.update(_td) masks_obj.all_assistant_mask = None masks_obj.padded = MetaData(self.pad_output) out.set(self.masks_key, masks_obj) if self.return_log_probs: if self.pad_output: log_probs_padded = outputs.get( "logprobs", as_padded_tensor=self.pad_output, padding_value=self.padding_value, padding_side="right", ) else: log_probs_list = outputs.get( "logprobs", as_list=True, ) self._check_not_padded(log_probs_list) if self.num_samples is None: # TODO: this is not correct, we should use the prompt_logprobs # but they're not returned by vLLM if self.pad_output: prompt_logprobs_padded = request_output_tc.get( "prompt_logprobs", as_padded_tensor=self.pad_output, padding_value=self.padding_value, padding_side="right", ) else: prompt_logprobs_list = request_output_tc.get( "prompt_logprobs", as_list=True, ) self._check_not_padded(prompt_logprobs_list) log_probs_obj = LogProbs._from_tensordict(out.empty()) if self.pad_output: self._check_padded(log_probs_padded) if self.num_samples 
is None: self._check_padded(prompt_logprobs_padded) log_probs_obj.prompt = prompt_logprobs_padded else: self._check_not_padded(log_probs_list) if self.num_samples is None: self._check_not_padded(prompt_logprobs_list) log_probs_obj.prompt = prompt_logprobs_list with log_probs_obj.view(-1) as log_probs_obj_flat: log_probs_obj_flat.response = ( log_probs_padded if self.pad_output else log_probs_list ) if self.num_samples is None: if self.pad_output: log_probs_obj_flat.full = self._cat_tensors( log_probs_obj_flat.prompt, log_probs_padded ) else: log_probs_obj_flat.full = self._cat_tensors( log_probs_obj_flat.get("prompt", as_list=True), log_probs_list, ) else: log_probs_obj_flat.full = None log_probs_obj.padded = MetaData(self.pad_output) out.set(self.log_probs_key, log_probs_obj) return out def _logprobs_from_tokens( self, *, response_struct: TensorDictBase | None = None, tokens_full_unpadded: list[torch.Tensor] | None = None, tokens_full_padded: torch.Tensor | None = None, sampling_params: SamplingParams | None = None, out: TensorDictBase | None = None, ) -> TensorDictBase: """Compute log-probs from tokens input.""" assert isinstance( response_struct, (TensorDictBase, type(None)) ), f"response_struct must be TensorDictBase or None, got {type(response_struct)}" assert isinstance( tokens_full_unpadded, (list, type(None)) ), f"tokens_full_unpadded must be list or None, got {type(tokens_full_unpadded)}" assert isinstance( tokens_full_padded, (torch.Tensor, type(None)) ), f"tokens_full_padded must be torch.Tensor or None, got {type(tokens_full_padded)}" assert isinstance( sampling_params, (SamplingParams, type(None)) ), f"sampling_params must be SamplingParams or None, got {type(sampling_params)}" assert isinstance( out, (TensorDictBase, type(None)) ), f"out must be TensorDictBase or None, got {type(out)}" # Convert to list format for vLLM if response_struct is not None: tokens_full_padded = response_struct.get( "input_ids", as_padded_tensor=True, padding_value=self.padding_value, padding_side="left", ) attention_mask_full_padded = response_struct.get( "attention_mask", as_padded_tensor=True, padding_value=False, padding_side="left", ).bool() attention_mask_full_unpadded = _unpad_tensors( attention_mask_full_padded, attention_mask_full_padded, as_nested=False ) elif tokens_full_unpadded is not None: tokens_full_padded = pad_sequence( tokens_full_unpadded, padding_value=self.padding_value, batch_first=True, padding_side="left", ) attention_mask_full_unpadded = [ t != self.padding_value for t in tokens_full_unpadded ] attention_mask_full_padded = pad_sequence( attention_mask_full_unpadded, padding_value=False, batch_first=True, padding_side="left", ) elif tokens_full_padded is not None: attention_mask_full_padded = tokens_full_padded != self.padding_value else: raise ValueError("Either response_struct or tokens must be provided") assert isinstance(tokens_full_padded, torch.Tensor) assert isinstance(attention_mask_full_padded, torch.Tensor) if tokens_full_unpadded is None: tokens_full_list = self._to_list( tokens_full_padded, attention_mask_full_padded ) else: tokens_full_list = self._to_list(tokens_full_unpadded, None) generate_kwargs = { "sampling_params": sampling_params, "prompt_token_ids": tokens_full_list, } # Generate with vLLM to get prompt_logprobs if not self._remote_calls: tokens_out_stuct = self.model.generate(**generate_kwargs) else: import ray tokens_out_stuct = ray.get(self.model.generate.remote(**generate_kwargs)) request_output_tc = 
_RequestOutput_tc.from_request_output(tokens_out_stuct) # Extract log-probs from prompt_logprobs if self.pad_output: # For padded case, use all prompt_logprobs log_probs_full_padded = request_output_tc.get( "prompt_logprobs", as_padded_tensor=True, padding_value=0, padding_side="left", ) # Mask out padding attention_mask_full_padded = tokens_full_padded != self.padding_value log_probs_full_padded = torch.where( attention_mask_full_padded, log_probs_full_padded, 0.0 ) else: # For unpadded case, extract from each sequence log_probs_full_unpadded = request_output_tc.get( "prompt_logprobs", as_list=True ) self._check_not_padded(log_probs_full_unpadded) assistant_mask_full_padded = None if response_struct is not None: assistant_mask_full_padded = response_struct.get( "assistant_masks", as_padded_tensor=True, padding_side="left", padding_value=0, ) if assistant_mask_full_padded is not None: assistant_mask_full_padded = assistant_mask_full_padded.bool() if not self.pad_output: assistant_mask_full_unpadded = _unpad_tensors( assistant_mask_full_padded, attention_mask_full_padded, as_nested=False, ) else: assistant_mask_full_unpadded = None else: assistant_mask_full_unpadded = None masks_obj = Masks._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(attention_mask_full_padded) masks_obj.all_attention_mask = attention_mask_full_padded.bool() if assistant_mask_full_padded is not None: masks_obj.all_assistant_mask = assistant_mask_full_padded else: self._check_not_padded(attention_mask_full_unpadded) masks_obj.all_attention_mask = attention_mask_full_unpadded if assistant_mask_full_unpadded is not None: masks_obj.all_assistant_mask = assistant_mask_full_unpadded masks_obj.padded = MetaData(self.pad_output) out.set(self.masks_key, masks_obj) tokens_obj = Tokens._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(tokens_full_padded) tokens_obj.full = tokens_full_padded else: tokens_obj.full = tokens_full_unpadded tokens_obj.response = None tokens_obj.padded = MetaData(self.pad_output) out.set(self.tokens_key, tokens_obj) log_probs_obj = LogProbs._from_tensordict( TensorDict(batch_size=out.batch_size).to_lazystack(0) ) if self.pad_output: self._check_padded(log_probs_full_padded) log_probs_obj.full = log_probs_full_padded else: self._check_not_padded(log_probs_full_unpadded) log_probs_obj.full = log_probs_full_unpadded log_probs_obj.response = None log_probs_obj.padded = MetaData(self.pad_output) out.set(self.log_probs_key, log_probs_obj) return out def _to_list( self, tokens_padded: torch.Tensor | list[torch.Tensor], attention_mask_padded: torch.Tensor | None, ) -> list[list[int]]: """Converts a tensor of integers into a masked list (of lists) of integers.""" if isinstance(tokens_padded, torch.Tensor): parent = [] queue = collections.deque() if attention_mask_padded is None: attention_mask_padded = torch.ones_like(tokens_padded) queue.append((tokens_padded, attention_mask_padded.bool(), parent)) while queue: token_tensor, attention_mask_bool, _parent = queue.popleft() if token_tensor.ndim == 1: _parent.extend(token_tensor[attention_mask_bool].tolist()) else: _parent.extend([[] for _ in range(token_tensor.shape[0])]) queue.extend( [ (t, m, local_parent) for t, m, local_parent in zip( token_tensor, attention_mask_bool, _parent ) ] ) tokens_list = parent elif isinstance(tokens_padded, list): parent = [] queue = collections.deque() queue.append((tokens_padded, parent)) while queue: 
tokens_list, _parent = queue.popleft() if isinstance(tokens_list, list) and isinstance( tokens_list[0], (list, torch.Tensor) ): _parent.extend([[] for _ in tokens_list]) queue.extend( [ (t, local_parent) for t, local_parent in zip(tokens_list, _parent) ] ) continue elif isinstance(tokens_list, torch.Tensor): tokens_list = tokens_list.tolist() _parent.extend(tokens_list) tokens_list = parent return tokens_list @_classproperty def CompletionOutput_tc(cls): if vllm is None: raise ImportError("vllm is required for CompletionOutput_tc") if hasattr(cls, "_CompletionOutput_tc"): return cls._CompletionOutput_tc CompletionOutput_tc = from_dataclass(vllm.outputs.CompletionOutput) # type: ignore cls._CompletionOutput_tc = CompletionOutput_tc return CompletionOutput_tc
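
    # Sketch of what `_to_list` does with a left-padded batch (assuming a padding value
    # of 0 for illustration; the real value comes from the tokenizer's pad token):
    #
    #     >>> tokens = torch.tensor([[0, 0, 11, 12], [21, 22, 23, 24]])
    #     >>> policy._to_list(tokens, tokens != 0)
    #     [[11, 12], [21, 22, 23, 24]]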
    def get_dist(
        self,
        tensordict: TensorDictBase,
        tensordict_out: TensorDictBase | None = None,
        logits_key: NestedKey = "logits",
        mask_key: NestedKey | None = None,
        as_padded_tensor: bool | None = None,
        as_nested_tensor: bool | None = None,
        padding_value: float | None = None,
        padding_side: str = "right",
        layout: torch.layout | None = None,
        **kwargs,
    ) -> D.Distribution:
        """Get distribution from logits/log-probs with optional masking.

        vLLM does not return logits, so this method is not supported.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_dist is not supported"
        )

    def get_dist_with_prompt_mask(
        self,
        tensordict: TensorDictBase,
        tokens_key: NestedKey = ("tokens", "full"),
        logits_key: NestedKey = "logits",
        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked to only include response tokens (exclude prompt).

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_dist_with_prompt_mask is not supported"
        )

    def _get_dist_with_assistant_mask(
        self,
        tensordict: TensorDictBase,
        assistant_mask_key: NestedKey = ("masks", "all_assistant_mask"),
        logits_key: NestedKey = "logits",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked to only include assistant tokens.

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_dist_with_assistant_mask is not supported"
        )

    def _get_dist_with_attention_mask(
        self,
        tensordict: TensorDictBase,
        attention_mask_key: NestedKey = ("masks", "all_attention_mask"),
        logits_key: NestedKey = "logits",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution masked using attention mask.

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_dist_with_attention_mask is not supported"
        )

    def _get_dist_with_custom_mask(
        self,
        tensordict: TensorDictBase,
        mask: torch.Tensor,
        logits_key: NestedKey = "logits",
        **kwargs,
    ) -> D.Distribution:
        """Get distribution with custom mask.

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_dist_with_custom_mask is not supported"
        )

    # Convenience methods for common LLM training scenarios
    def _get_sft_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for SFT loss (response tokens only).

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_sft_dist is not supported"
        )

    def _get_rlhf_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for RLHF loss (assistant tokens only).

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_rlhf_dist is not supported"
        )

    def _get_generic_dist(self, tensordict: TensorDictBase, **kwargs) -> D.Distribution:
        """Get distribution suitable for generic losses (all tokens).

        vLLM does not return logits, so this method is not supported.
        This is a provisional method that will be replaced by the `get_dist` method once we have a better masking strategy.
        """
        raise NotImplementedError(
            "vLLM does not return logits, so get_generic_dist is not supported"
        )
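
    # Note: every distribution helper above raises NotImplementedError because vLLM does
    # not expose logits. One possible workaround (a sketch, assuming the same weights are
    # also loadable with transformers) is to build a TransformersWrapper with matching
    # keys and query distributions there:
    #
    #     >>> from torchrl.modules.llm.policies import TransformersWrapper
    #     >>> ref = TransformersWrapper(hf_model, tokenizer=tokenizer, input_mode="history", generate=False)
    #     >>> dist = ref.get_dist(td)  # hypothetical usage; see the TransformersWrapper docs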
class _RequestOutput_tc(TensorClass["nocast"]):
    """TensorClass wrapper for vLLM RequestOutput."""

    request_id: str
    prompt: str
    prompt_token_ids: torch.Tensor
    prompt_logprobs: torch.Tensor
    outputs: list  # type: ignore
    finished: str
    metrics: str
    lora_request: str
    encoder_prompt: str
    encoder_prompt_token_ids: str
    num_cached_tokens: torch.Tensor

    def __post_init__(self):
        CompletionOutput_tc = vLLMWrapper.CompletionOutput_tc

        def postproc(output):
            def get_logprob(output):
                t = []
                for v, tid in zip(output.logprobs, output.token_ids):
                    t.append(
                        v[int(tid)]["logprob"]
                        if v[tid].get("logprob") is not None
                        else 0.0
                    )
                return torch.tensor(t)

            if output.logprobs:
                output.logprobs = get_logprob(output)
            output.token_ids = torch.as_tensor(output.token_ids)
            return output

        if isinstance(self.outputs, list):
            outputs = self.outputs
            outputs = [
                postproc(from_dataclass(output, dest_cls=CompletionOutput_tc))
                for output in outputs
            ]
            if len(outputs) == 1:
                self.outputs = outputs[0]
            else:
                # Check if we can stack the outputs (they should have the same shape)
                try:
                    self.outputs = lazy_stack(outputs)
                except RuntimeError:
                    # If stacking fails (different sizes), keep as list
                    self.outputs = outputs

    @classmethod
    def from_request_output(
        cls, requests: RequestOutput | list[RequestOutput]
    ) -> _RequestOutput_tc | list[_RequestOutput_tc]:
        """Create _RequestOutput_tc from vLLM RequestOutput."""
        # Type assertions
        assert isinstance(
            requests, (RequestOutput, list)
        ), f"requests must be RequestOutput or list, got {type(requests)}"

        # Check if we can stack the outputs
        try:
            out = lazy_stack(
                [
                    cls(
                        request_id=request.request_id,
                        prompt=request.prompt,
                        prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
                        prompt_logprobs=torch.tensor(
                            [
                                v[int(tid)].logprob if v is not None else 0.0
                                for v, tid in _zip_strict(
                                    request.prompt_logprobs, request.prompt_token_ids
                                )
                            ]
                        )
                        if request.prompt_logprobs is not None
                        else torch.tensor([]),
                        outputs=request.outputs,
                        finished=request.finished,
                        metrics=request.metrics,
                        lora_request=request.lora_request,
                        encoder_prompt=request.encoder_prompt,
                        encoder_prompt_token_ids=request.encoder_prompt_token_ids,
                        num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
                    )
                    for request in requests
                ]
            )
            return out
        except RuntimeError:
            # If stacking fails, return a list of individual _RequestOutput_tc objects
            return [
                cls(
                    request_id=request.request_id,
                    prompt=request.prompt,
                    prompt_token_ids=torch.as_tensor(request.prompt_token_ids),
                    prompt_logprobs=torch.tensor(
                        [
                            v[int(tid)].logprob if v is not None else 0.0
                            for v, tid in _zip_strict(
                                request.prompt_logprobs, request.prompt_token_ids
                            )
                        ]
                    )
                    if request.prompt_logprobs is not None
                    else torch.tensor([]),
                    outputs=request.outputs,
                    finished=request.finished,
                    metrics=request.metrics,
                    lora_request=request.lora_request,
                    encoder_prompt=request.encoder_prompt,
                    encoder_prompt_token_ids=request.encoder_prompt_token_ids,
                    num_cached_tokens=torch.as_tensor(request.num_cached_tokens),
                )
                for request in requests
            ]
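
# End-to-end sketch: scoring an existing conversation with log-probs only. Assumes vllm
# and transformers are installed; the checkpoint name is illustrative.
#
#     >>> from tensordict import TensorDict
#     >>> from torchrl.data.llm import History
#     >>> from torchrl.modules.llm.policies import ChatHistory, vLLMWrapper
#     >>> scorer = vLLMWrapper(
#     ...     "Qwen/Qwen2.5-0.5B-Instruct", input_mode="history", generate=False
#     ... )
#     >>> chats = History.from_chats([[
#     ...     {"role": "user", "content": "Hello"},
#     ...     {"role": "assistant", "content": "Hi there!"},
#     ... ]])
#     >>> td = TensorDict(history=ChatHistory(full=chats), batch_size=(1,))
#     >>> out = scorer(td)
#     >>> out["log_probs"].full  # per-token log-probabilities of the full sequence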
