
LLM Interface

TorchRL provides a comprehensive framework for LLM post-training and fine-tuning. The LLM API is built around five core concepts that work together to create a complete reinforcement learning pipeline for language models:

  1. Data Representation (Data Structures): The foundation for handling conversations, text parsing, and LLM output classes. This includes the History class for managing conversation context and structured output classes for tokens, log-probabilities, and text.

  2. LLM Wrapper API (Modules): Unified interfaces for different LLM backends, including TransformersWrapper for Hugging Face models and vLLMWrapper for vLLM inference. These wrappers provide consistent input/output formats across different backends and an integrated interface for loss computation, data storage, grading, weight synchronization, etc.

  3. Environments (Environments): The orchestration layer that manages data loading, tool execution, reward computation, and formatting. This includes ChatEnv for conversation management, dataset environments, and various transforms for tool integration.

  4. Objectives (Objectives): Specialized loss functions for LLM training, including GRPOLoss for Group Relative Policy Optimization and SFTLoss for supervised fine-tuning.

  5. Collectors (Collectors): Collectors are used to collect data from the environment and store it in a format that can be used for training. This includes LLMCollector for collecting data from the environment and RayLLMCollector for collecting data in distributed settings using Ray.

These components work together to create a complete pipeline: environments load and format data, LLM wrappers handle inference, data structures maintain conversation context, and objectives compute training losses. The modular design allows you to mix and match components based on your specific use case.

A complete example of how to use the LLM API can be found in the sota-implementations/grpo/ directory. The training orchestration involves three main components:

  • The Data Collector: holds a reference to the environment and the inference model or engine. It collects data, puts it in the buffer, and handles weight updates.

  • The Replay Buffer: stores the collected data and executes any pre- or post-processing steps. These may include advantage estimation with a Monte-Carlo based method (using the MCAdvantage transform), grading of the outputs, logging, etc.

  • The Trainer: handles the training loop, including the optimization step, serialization, logging, and the initialization of weight updates.
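Schematically, and with placeholder names (collector, rb, loss_fn, optim, num_epochs and batch_size are not defined here; see the sota-implementations/grpo/ scripts for the real wiring), the orchestration looks like the following pseudocode:

# Schematic training loop (placeholder names, pseudocode)
for _ in collector:                      # inference: the collector fills the replay buffer
    for _ in range(num_epochs):
        batch = rb.sample(batch_size)    # sample trajectories written by the collector
        loss_vals = loss_fn(batch)       # e.g., GRPOLoss -> a tensordict of loss terms
        loss = sum(v for k, v in loss_vals.items() if k.startswith("loss"))
        loss.backward()
        optim.step()
        optim.zero_grad()
    collector.update_policy_weights_()   # push new weights to the inference engine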

Warning

The LLM API is still under development and may change in the future. Feedback, issues and PRs are welcome!

Data Structures

The data representation layer provides the foundation for handling conversations and LLM outputs in a structured way.

History Class

The History class is a TensorClass version of the chat format usually found in transformers (see Hugging Face chat documentation). It provides a comprehensive API for managing conversation data with features including:

  • Text parsing and formatting: Convert between text and structured conversation format using from_text() and apply_chat_template()

  • Dynamic conversation building: Append and extend conversations with append() and extend() methods

  • Multi-model support: Automatic template detection for various model families (Qwen, DialoGPT, Falcon, DeepSeek, etc.)

  • Assistant token masking: Identify which tokens were generated by the assistant for reinforcement learning applications

  • Tool calling support: Handle function calls and tool responses in conversations

  • Batch operations: Efficient tensor operations for processing multiple conversations simultaneously.

History(role, content[, is_complete, ...])

ContentBase(type, text, url, data, ...[, ...])

Supported Model Families

We currently support the following model families for string-to-History parsing and assistant token masking:

  • Qwen family (e.g., Qwen/Qwen2.5-0.5B): Custom template with full tool calling support

  • DialoGPT family (e.g., microsoft/DialoGPT-medium): Custom template for conversation format

  • Falcon family (e.g., tiiuae/falcon-7b-instruct): Custom template for instruction format

  • DeepSeek family (e.g., deepseek-ai/deepseek-coder-6.7b-base): Custom template with native format

Other models are supported, but you will need to provide a custom template for them. LLAMA, Mistral, OPT, GPT, MPT, BLOOM, Pythia, Phi, etc. will use the default chatml_format template.

Usage

>>> from torchrl.data.llm.chat import History
>>> from transformers import AutoTokenizer
>>>
>>> # Create a conversation history
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"},
...     {"role": "user", "content": "How are you?"},
...     {"role": "assistant", "content": "I'm doing well, thanks!"}
... ]])
>>>
>>> # Load any supported tokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
>>>
>>> # Apply chat template with assistant token masking
>>> result = history.apply_chat_template(
...     chat_template_name="qwen",
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )
>>>
>>> # The result contains an assistant_masks tensor
>>> assistant_masks = result["assistant_masks"]
>>> print(f"Assistant tokens: {assistant_masks.sum().item()}")

Adding Custom Templates

You can add custom chat templates for new model families using the torchrl.data.llm.chat.add_chat_template() function.

Usage Examples

Adding a Llama Template
>>> from torchrl.data.llm.chat import add_chat_template, History
>>> from transformers import AutoTokenizer
>>>
>>> # Define the Llama chat template
>>> llama_template = '''
... {% for message in messages %}
... {%- if message['role'] == 'user' %}
... {{ '<s>[INST] ' + message['content'] + ' [/INST]' }}
... {%- elif message['role'] == 'assistant' %}
... {% generation %}{{ message['content'] + '</s>' }}{% endgeneration %}
... {%- endif %}
... {% endfor %}
... {%- if add_generation_prompt %}
... {% generation %}{{ ' ' }}{% endgeneration %}
... {%- endif %}
... '''
>>>
>>> # Define the inverse parser for Llama format
>>> def parse_llama_text(text: str) -> History:
...     import re
...     from tensordict import lazy_stack
...     pattern = r'<s>\[INST\]\s*(.*?)\s*\[/INST\]\s*(.*?)</s>'
...     matches = re.findall(pattern, text, re.DOTALL)
...     messages = []
...     for user_content, assistant_content in matches:
...         messages.append(History(role="user", content=user_content.strip()))
...         messages.append(History(role="assistant", content=assistant_content.strip()))
...     return lazy_stack(messages)
>>>
>>> # Add the template with auto-detection
>>> add_chat_template(
...     template_name="llama",
...     template=llama_template,
...     inverse_parser=parse_llama_text,
...     model_family_keywords=["llama", "meta-llama"]
... )
>>>
>>> # Now you can use it with auto-detection
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
>>> history = History.from_chats([[
...     {"role": "user", "content": "Hello"},
...     {"role": "assistant", "content": "Hi there!"}
... ]])
>>>
>>> # Auto-detection will use the llama template
>>> result = history.apply_chat_template(
...     tokenizer=tokenizer,
...     add_generation_prompt=False,
...     return_dict=True,
...     return_assistant_tokens_mask=True,
... )

Testing Your Custom Templates

When adding custom templates, you should test them to ensure they work correctly. Here are the recommended tests:

Assistant Token Masking Test

Test that your template supports assistant token masking:

import pytest
from torchrl.data.llm.chat import History, add_chat_template
from transformers import AutoTokenizer

def test_my_model_assistant_masking():
    """Test that your model supports assistant token masking."""
    # Add your template first
    add_chat_template(
        template_name="my_model",
        template="your_template_here",
        model_family_keywords=["my_model"]
    )

    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    result = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        return_dict=True,
        return_assistant_tokens_mask=True,
    )

    # Verify assistant mask is present
    assert 'assistant_masks' in result
    assert result['assistant_masks'].shape[0] == 1, "Should have batch dimension of 1"
    assert result['assistant_masks'].shape[1] > 0, "Should have sequence length > 0"

    # Verify some assistant tokens are masked
    assistant_token_count = result['assistant_masks'].sum().item()
    assert assistant_token_count > 0, "Should have assistant tokens masked"
    print(f"✓ {assistant_token_count} assistant tokens masked")
Template Equivalence Test

Test that your custom template produces the same output as the model’s default template (except for masking):

def test_my_model_template_equivalence():
    """Test that your template matches the model's default template."""
    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'},
        {'role': 'user', 'content': 'How are you?'},
        {'role': 'assistant', 'content': 'I\'m good, thanks!'},
    ]])

    # Get output with model's default template
    try:
        default_out = history.apply_chat_template(
            tokenizer=tokenizer,
            add_generation_prompt=False,
            chat_template=tokenizer.chat_template,
            tokenize=False,
        )
    except Exception as e:
        default_out = None
        print(f"[WARN] Could not get default template: {e}")

    # Get output with your custom template
    custom_out = history.apply_chat_template(
        tokenizer=tokenizer,
        add_generation_prompt=False,
        chat_template_name="my_model",
        tokenize=False,
    )

    if default_out is not None:
        # Normalize whitespace for comparison
        import re
        def norm(s):
            return re.sub(r"\s+", " ", s.strip())

        assert norm(default_out) == norm(custom_out), (
            f"Custom template does not match default!\n"
            f"Default: {default_out}\nCustom: {custom_out}"
        )
        print("✓ Template equivalence verified")
    else:
        print("[INFO] Skipped equivalence check (no default template available)")
Inverse Parsing Test

If you provided an inverse parser, test that it works correctly:

def test_my_model_inverse_parsing():
    """Test that your inverse parser works correctly."""
    tokenizer = AutoTokenizer.from_pretrained("your/model/name")
    history = History.from_chats([[
        {'role': 'user', 'content': 'Hello'},
        {'role': 'assistant', 'content': 'Hi there!'}
    ]])

    # Format using your template
    formatted = history.apply_chat_template(
        tokenizer=tokenizer,
        chat_template_name="my_model",
        add_generation_prompt=False,
        tokenize=False,
    )

    # Parse back using your inverse parser
    parsed = History.from_text(formatted, chat_template_name="my_model")

    # Verify the parsing worked
    assert parsed.role == history.role
    assert parsed.content == history.content
    print("✓ Inverse parsing verified")

LLM Wrapper API

The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines. The main wrappers are TransformersWrapper for Hugging Face models and vLLMWrapper for vLLM inference.

Data Structure Classes

The wrappers use structured TensorClass objects to represent different aspects of LLM data:

  • Text: Contains text data with prompt, response, and full fields

  • ChatHistory: Contains History objects with prompt, response, and full fields

  • Tokens: Contains tokenized data with prompt, response, and full fields

  • LogProbs: Contains log probabilities with prompt, response, and full fields

  • Masks: Contains attention and assistant masks
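These classes can be constructed directly. The sketch below assumes the fields are plain keyword arguments, as suggested by the signatures listed in the Modules section:

from torchrl.data.llm.chat import History
from torchrl.modules.llm.policies import ChatHistory, Text

# A prompt-only ChatHistory: response and full stay empty until a wrapper fills them
prompt = History.from_chats([[{"role": "user", "content": "Hello"}]])
chat = ChatHistory(prompt=prompt, batch_size=(1,))

# The same prompt/response/full layout applies to raw text
text = Text(prompt="Hello!", batch_size=())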

API Flow

The wrappers operate in two distinct modes:

Generation Mode (`generate=True`):

  • Input: Reads from prompt fields (e.g., history.prompt, text.prompt, tokens.prompt)

  • Output: Writes to both response and full fields:

    - response: Contains only the newly generated content

    - full: Contains the complete sequence (prompt + response)

Log-Probability Mode (`generate=False`):

  • Input: Reads from full fields (e.g., history.full, text.full, tokens.full)

  • Output: Writes log probabilities to the corresponding full fields
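The mode is selected at construction time. The sketch below assumes a Hugging Face model and tokenizer, and uses the constructor argument names referenced throughout this section (check the wrapper signatures in the Modules section for the authoritative list):

from torchrl.modules.llm import TransformersWrapper

# Generation mode: reads *.prompt, writes *.response and *.full
policy = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    input_mode="history",
    generate=True,
)

# Log-probability mode: reads *.full and writes the corresponding log-probabilities
ref_model = TransformersWrapper(
    model,
    tokenizer=tokenizer,
    input_mode="history",
    generate=False,
    return_log_probs=True,  # argument name assumed
)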

LLM-Environment Interaction Loop

Figure: LLM-Environment interaction: the LLM generates a response, the environment updates the conversation, and transforms can inject new messages or tools.

In a typical RL or tool-augmented setting, the LLM and environment interact in a loop:

  1. LLM Generation: The LLM wrapper receives a prompt (the current conversation history), generates a response, and outputs a full field containing the concatenation of the prompt and response.

  2. Environment Step: The environment takes the full field and makes it the next prompt for the LLM. This ensures that the conversation context grows with each turn. See ref_env_llm_step for more details.

  3. Transforms: Before the next LLM step, transforms can modify the conversation, for example by inserting a new user message, a tool call, or a reward annotation.

  4. Repeat: This process repeats for as many turns as needed, enabling multi-turn dialogue, tool use, and RL training.

This design allows for flexible augmentation of the conversation at each step, supporting advanced RL and tool-use scenarios.

A typical pseudocode loop:

# Get the first prompt out of an initial query
obs = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size, device=env.device))
while not done:
    # LLM generates a response given the current prompt
    llm_output = llm(obs)
    # Environment steps: creates a ("next", "history") field with the new prompt (from the previous `"full"` field)
    obs = env.step(llm_output)

Integration with History

When using input_mode="history", the wrapper integrates seamlessly with the History class:

  • Input: Takes a ChatHistory object containing a History in the prompt field

  • Generation: Applies chat templates to convert History to tokens, generates response, then parses the full text back into a History object

  • Output: Returns a ChatHistory with:

    - prompt: Original conversation history

    - response: New History object containing only the assistant's response

    - full: Complete conversation history with the new response appended

This design allows for natural conversation flow where each generation step extends the conversation history, making it ideal for multi-turn dialogue systems.
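In practice, a generation step in history mode looks roughly like the following sketch (policy is a wrapper built with input_mode="history" and generate=True; the key names follow the convention described above):

from tensordict import TensorDict
from torchrl.data.llm.chat import History
from torchrl.modules.llm.policies import ChatHistory

history = History.from_chats([[{"role": "user", "content": "Hello"}]])
td = TensorDict(history=ChatHistory(prompt=history, batch_size=(1,)), batch_size=(1,))

out = policy(td)
response = out["history"].response  # History with only the new assistant turn(s)
full = out["history"].full          # prompt + response, used by the env as the next prompt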

Prompt vs. Response and padding

Figure: Structure of LLM outputs: padded vs. sparse representations for Tokens, LogProbs, and Masks.

The diagram above illustrates the structure of the main output classes used in TorchRL’s LLM API:

  • Tokens (and by extension, LogProbs):

    - Padded format: All sequences in a batch are padded to the same length (with a special pad token), making them suitable for tensor operations. The prompt and response are concatenated to form tokens.full, and masks indicate valid vs. padded positions.

    - Sparse format: Each sequence retains its original length (no padding), represented as lists of tensors. This is more memory-efficient for variable-length data.

  • Masks: Two main masks are shown:

    - mask.attention_mask_all marks valid (non-pad) tokens.

    - mask.assistant_mask_all marks which tokens were generated by the assistant (useful for RLHF and SFT training).

  • Text: Not shown in detail, as it is simply the decoded string representation of the prompt, response, or full sequence.

This format ensures that all LLM outputs (Tokens, LogProbs, Masks, Text) are consistent and easy to manipulate, regardless of whether you use padded or sparse batching.

In general, we recommend working with unpadded data, as it is more memory-efficient and easier to manipulate. For instance, when collecting multiple padded elements from the buffer, it can be unclear how to re-pad them into a cohesive batch; working with unpadded data avoids this issue.
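To make the distinction concrete, here is a generic sketch in plain PyTorch (not tied to the classes above) of the same batch in sparse and padded form:

import torch
from torch.nn.utils.rnn import pad_sequence

# Sparse (unpadded): each sequence keeps its own length
sparse = [torch.tensor([101, 42, 7]), torch.tensor([101, 13])]

# Padded: right-pad to a common length; a mask marks the valid (non-pad) positions
padded = pad_sequence(sparse, batch_first=True, padding_value=0)                       # shape [2, 3]
attention_mask = pad_sequence([torch.ones_like(t) for t in sparse], batch_first=True)  # 1 = valid token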

Modules

The LLM wrapper API provides unified interfaces for different LLM backends, ensuring consistent input/output formats across training and inference pipelines.

Wrappers

The main goal of these primitives is to:

  • Unify the input/output data format across training and inference pipelines

  • Unify the input/output data format across backends (to be able to use different backends across losses and collectors)

  • Provide appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.)

LLMWrapperBase(*args, **kwargs)

An LLM wrapper base class.

TransformersWrapper(*args, **kwargs)

A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.

vLLMWrapper(*args, **kwargs)

A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation.

ChatHistory([prompt, response, full, ...])

Text([prompt, response, full, device, names])

LogProbs([prompt, response, full, padded, ...])

Masks([all_attention_mask, ...])

Tokens([prompt, response, full, padded, ...])

Utils

LLMOnDevice(*args[, bundle_indices])

A thin wrapper around vllm.LLM to control its device placement.

make_vllm_worker(*, model_name[, devices, ...])

Creates a vLLM inference engine with tensor parallelism support.

stateless_init_process_group(master_address, ...)

Initializes a stateless process group for distributed communication.

vLLMWorker(*args, **kwargs)

vLLM worker for Ray.

Collectors

TorchRL offers specialized collector classes (LLMCollector and RayLLMCollector) that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.

See ref_collectors for more details on the collector API. In brief, the idea of a collector is to isolate the inference part of the pipeline in a dedicated class. A collector usually takes a policy and an environment as input, and alternates between running one and the other. In “classical” settings, the policy is similar to the policy being trained (with some optional extra exploration). In the context of LLM fine-tuning, the policy is usually a specialized inference engine, such as a vLLM server. Collectors are defined by the following parameters and features:

  • Sync/Async: Whether the collector should run in sync or async mode. In sync mode, the collector alternates between the inference step and the optimization/training step. In async mode, the collector runs the inference step in parallel with the optimization/training step. A replay buffer can be passed to the collector so that it writes collected data directly to the buffer; otherwise, the collector can be iterated over to collect data.

  • Steps: A collector is built with a total step budget, as well as the number of steps to include in each batch yielded during collection.

  • Weight Updater: Weight updaters are the classes that update the policy weights. Isolating the weight update in a dedicated class makes it easy to implement different weight-update strategies depending on the policy specification.
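A minimal construction sketch is shown below. It assumes env and policy already exist, and the two budget argument names (dialog_turns_per_batch, total_dialog_turns) are assumptions to be checked against the LLMCollector signature:

from torchrl.collectors.llm import LLMCollector
from torchrl.data import LazyStackStorage, ReplayBuffer

rb = ReplayBuffer(storage=LazyStackStorage(1_000))
collector = LLMCollector(
    env,                        # e.g., a (transformed) ChatEnv
    policy=policy,              # e.g., a vLLMWrapper
    dialog_turns_per_batch=16,  # steps yielded per batch (argument name assumed)
    total_dialog_turns=1_024,   # total step budget (argument name assumed)
    replay_buffer=rb,           # the collector writes collected steps directly to the buffer
)
for _ in collector:             # iterating drives collection; training batches are sampled from rb
    ...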

Policy Version Tracking

LLM collectors also allow tracking of the policy version, which is useful for some use cases. This is done by adding a PolicyVersion transform to the environment, which is then incremented by the collector after each weight update. To do this, one can either provide the stateful version of the transform or simply pass a boolean to the collector constructor.

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

vLLMUpdater([master_address, master_port, ...])

A class that sends weights to vLLM workers.

LLMCollector(env, *[, policy, ...])

A simplified version of SyncDataCollector for LLM inference.

RayLLMCollector(env, *[, policy, ...])

A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.

Environments

The environment layer orchestrates data loading, tool execution, reward computation, and formatting. When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the policy and collector.

ChatEnv

ChatEnv serves as a blank canvas for LLM environments - it’s a basic tool designed to be extended with transforms that add specific functionality. The base ChatEnv provides the fundamental structure for managing conversation state using the History format, but it’s intentionally minimal to allow maximum flexibility.

Core Functionality

ChatEnv operates in three main modes:

  • History mode: Uses History objects for conversation management

  • Text mode: Uses simple text strings for input/output

  • Tokens mode: Uses tokenized data for input/output

The environment maintains conversation state as follows:

  • Reset: Initializes a new conversation with an optional system prompt

  • Step: Takes the LLM's response and updates the conversation history, preparing the next prompt
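A minimal reset sketch in history mode is shown below (the constructor arguments mirror those used in the transform example further down, and the "query" reset key follows the pseudocode loop shown earlier):

from tensordict import TensorDict
from torchrl.envs.llm import ChatEnv
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
env = ChatEnv(
    batch_size=(1,),
    system_prompt="You are a helpful assistant.",
    tokenizer=tokenizer,
)
reset = env.reset(TensorDict({"query": "Hello!"}, batch_size=env.batch_size))
prompt = reset["history"].prompt  # a History stack: [system prompt, user query]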

Transform-Based Architecture

Transforms are the main way to extend ChatEnv with specific capabilities, such as loading data from a dataset, executing tools, or computing rewards (see the Transforms section below).

Integration with LLM Wrappers

ChatEnv is designed to work seamlessly with both TransformersWrapper and vLLMWrapper. The environment handles the conversation state management while the wrapper handles the actual LLM inference, creating a clean separation of concerns.

On each call to step, the environment:

  • Takes the LLM’s output, specifically the full field, which contains the entire conversation so far, including the new response (e.g., history.full, text.full, tokens.full).

  • Sets this full field as the new prompt for the next LLM step (e.g., td["next", "history"].prompt, td["next", "text"].prompt, td["next", "tokens"].prompt).

  • Optionally applies transforms that insert new user messages, tool calls, or other modifications to the conversation, refining the prompt before the next LLM step.

This mechanism enables seamless multi-turn interactions and supports complex workflows such as tool use and reward shaping.

Task-Specific Environments

We provide a few task-specific environments, such as GSM8KEnv for the GSM8K dataset, IFEvalEnv for the IFEval dataset, and make_mlgym / MLGymWrapper for MLGym integration.

These environments wrap a ChatEnv and add a DataLoadingPrimer transform (plus an optional reward parsing transform) in a TransformedEnv.
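For example, a dataset environment can typically be built in one line. The sketch below uses GSM8KEnv; the num_envs and repeats argument names are taken from the make_gsm8k_env signature listed below and are assumed to apply here as well:

from torchrl.envs.llm import GSM8KEnv
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
# ChatEnv + DataLoadingPrimer (GSM8K questions) + reward parser, bundled in a TransformedEnv
env = GSM8KEnv(tokenizer=tokenizer, num_envs=2, repeats=4)  # repeats: each prompt is served several times (e.g., for GRPO)
reset = env.reset()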

ChatEnv(*args, **kwargs)

A chat-based environment for LLMs, designed as a blank canvas for conversation and RL.

DatasetChatEnv(*args, **kwargs)

Base class for chat environment with queries pulled from a dataset.

GSM8KEnv(*args, **kwargs)

GSM8K dataset environment.

make_gsm8k_env([dataset, num_envs, repeats, ...])

A builder for an LLMEnv-based GSM8K environment.

GSM8KPrepareQuestion([in_keys, out_keys])

A transform to prepare the prompt when using GSM8k within an LLMEnv.

IFEvalEnv(*args, **kwargs)

A chat environment based on the IFEval dataset.

IfEvalScorer(*[, instruction_ids_key, ...])

Scorer for the IF-Eval task.

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

A text generation environment for language models.

LLMHashingEnv(*args, **kwargs)

A text generation environment that uses a hashing module to identify unique observations.

make_mlgym(*[, task, tasks, tokenizer, ...])

Wraps an MLGymEnv in a TorchRL Environment.

MLGymWrapper(*args, **kwargs)

A thin wrapper for MLGym environments.

GSM8KRewardParser(tokenizer[, in_keys, ...])

Reward parser for GSM8KEnv or make_gsm8k_env.

Transforms

Transforms are used to modify the data before it is passed to the LLM. Tools are usually implemented as transforms, and appended to a base environment such as ChatEnv.

An example of a tool transform is the PythonInterpreter transform, which is used to execute Python code in the context of the LLM.

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(step_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

Similarly, environments that load data from a dataset are just special instances of ChatEnv augmented with a DataLoadingPrimer transform (and some dedicated reward parsing transforms).

Designing Reward Transforms

When designing reward transforms for LLM environments, several key considerations must be addressed to ensure proper integration with the training pipeline. The examples of GSM8KRewardParser and IfEvalScorer provide excellent templates for reward transform design.

Reward Shape Requirements

The reward tensor must have the same number of dimensions as the logits, which is typically two more dimensions than the environment batch size:

  • Sparse rewards: Shape (*bsz, 1, 1) - single reward per sequence

  • Dense rewards: Shape (*bsz, num_tokens, 1) - per-token rewards

This shape requirement ensures compatibility with the loss computation pipeline. For example, in the GSM8K reward parser:

# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))

Done State Management

It is crucial to properly manage the done state to prevent endless generation. Common strategies include:

  1. Completion-based termination: Set done when the response is complete (e.g., History.complete=True)

  2. Content-based termination: Set done when specific content is detected (e.g., <answer> blocks)

  3. Step-based termination: Use StepCounter for predetermined step limits

Example from IFEvalScorer:

if self.set_done_if_answer and bool(answer_blocks):
    next_tensordict.set("done", torch.ones(...))
    next_tensordict.set("terminated", torch.ones(...))

Input Mode Handling

Reward transforms must handle different input modes correctly:

  • History mode: Extract text from ("history", "full") or ("history", "response")

  • Text mode: Use text directly from ("text", "full") or ("text", "response")

  • Tokens mode: Decode tokens from ("tokens", "full") or ("tokens", "response")

The GSM8K reward parser demonstrates this pattern:

if input_mode == "history":
    responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
    if hasattr(responses, "content"):
        text_completion = responses.content
elif input_mode == "text":
    text_completion = responses
elif input_mode == "tokens":
    text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())

Specification Management

Accurate specification of reward and observation specs is essential for proper environment initialization. Both GSM8K and IFEval provide good examples:

def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    shape = reward_spec.shape + (1, 1)
    reward_spec.update(
        Composite(
            reward_answer=Unbounded(shape),
            reward_think=Unbounded(shape),
            reward_right=Unbounded(shape),
            reward_contained=Unbounded(shape),
            reward=Unbounded(shape),
            success=Unbounded(shape, dtype=torch.bool),
        )
    )
    return reward_spec

Batch Processing Considerations

For efficient processing, handle batched data appropriately:

  1. Flatten batch dimensions: Use tensordict.view(-1) for processing

  2. Reshape results: Restore original batch structure after processing

  3. Handle variable-length sequences: Use proper padding and masking
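A generic sketch of this pattern inside a reward transform (next_tensordict is the tensordict being processed, and score_fn is a hypothetical per-sample scorer):

import torch

# Flatten the leading batch dims, score each sample, then reshape back
flat = next_tensordict.view(-1)
rewards = torch.stack([score_fn(td) for td in flat.unbind(0)])  # score_fn: hypothetical
rewards = rewards.reshape(*next_tensordict.batch_size, 1, 1)    # broadcastable to [*bsz, tokens, 1]
next_tensordict.set("reward", rewards.float())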

Reward Aggregation Strategies

Consider different reward aggregation approaches:

  1. Simple aggregation: Sum or average multiple reward components

  2. Weighted aggregation: Apply different weights to different components

  3. Conditional rewards: Base rewards on specific conditions or thresholds

The IFEvalScorer demonstrates a sophisticated aggregation strategy:

def default_reward_aggregator(self, score: IFEvalScoreData, ...):
    # Format score (max 1.0)
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure score (max 1.0)
    structure_score = think_score + answer_score

    # Completion bonus (max 0.2)
    completion_bonus = float(complete) * 0.2

    return format_score + structure_score + completion_bonus

Post-Processing in Replay Buffers

Rewards can also be computed after the fact by appending transforms to the replay buffer. However, done state capture must remain in the environment transform since it needs to occur on-the-fly during data collection.
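For instance (a sketch; the GSM8KRewardParser import path is assumed, and whether a given reward transform can run at sampling time depends on the keys stored in the buffer):

from torchrl.data import LazyStackStorage, ReplayBuffer
from torchrl.envs.llm import GSM8KRewardParser  # import path assumed

rb = ReplayBuffer(storage=LazyStackStorage(1_000))
# The reward is parsed when batches pass through the buffer rather than inside the environment
rb.append_transform(GSM8KRewardParser(tokenizer=tokenizer))  # tokenizer: a Hugging Face tokenizer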

Error Handling and Robustness

Implement robust error handling for parsing failures:

try:
    cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
    cot, potential_answer = ("", "")

Performance Considerations

  1. Avoid redundant computations: Cache parsed results when possible

  2. Use efficient text processing: Leverage regex or XML parsing as appropriate

  3. Minimize memory allocations: Reuse tensors and avoid unnecessary copies

By following these design principles, reward transforms can be effectively integrated into the LLM training pipeline while maintaining performance and reliability.

AddThinkingPrompt(cond[, prompt, ...])

A transform that adds thinking prompts to encourage the LLM to reconsider its response.

BrowserTransform([allowed_domains, ...])

A transform that enables web browsing capabilities.

DataLoadingPrimer(dataloader, *[, primers, ...])

A primer that loads data from a dataloader and converts it into a tensordict using stack_method.

KLComputation([gen_log_probs_full_key, ...])

A transform to compute KL divergence between two log-prob tensors and optionally add it to the reward.

KLRewardTransform(ref_model, *[, coef, ...])

A legacy transform for computing KL divergence-based rewards.

MCPToolTransform(tools, tool_schemas[, ...])

A transform that executes MCP-style tools in response to LLM actions.

PolicyVersion([version_type])

A transform that keeps track of the version of the policy.

PythonInterpreter([tokenizer, tool_name, ...])

A transform that executes Python code in the LLM response.

RetrieveKL([gen_model, ref_model, ...])

A transform to retrieve the KL divergence between two models' log-probabilities.

RetrieveLogProb(model, *[, ...])

A transform to retrieve log-probabilities from a model for KL divergence computation.

TemplateTransform(tokenizer[, chat_template])

A transform that applies a chat template to an input string during the forward pass, and parses strings back from the template during the backward pass.

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

Applies a tokenization operation on the specified inputs.

as_nested_tensor(list_of_tensordicts)

Stacks a list of tensordicts into a single tensordict with nested tensors.

as_padded_tensor(list_of_tensordicts[, dim, ...])

Stacks a list of tensordicts into a single tensordict with padded tensors.

Objectives

LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models.

GRPO

The GRPOLoss class is a thin wrapper around the PPOLoss class that adds the LLM-specific functionality.

GRPOLoss(*args, **kwargs)

GRPO loss.

GRPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

Monte-Carlo advantage computation engine.

SFT

SFTLoss(*args, **kwargs)

Supervised fine-tuning loss.

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

A replay-buffer transform that selects the top-k rewards for each prompt.
