
LLM Interface

TorchRL provides a comprehensive framework for LLM post-training and fine-tuning. The LLM API is built around five core concepts that work together to create a complete reinforcement learning pipeline for language models.

Key Components

  1. Data Structures: History class for conversation management, structured output classes

  2. LLM Wrappers: Unified interfaces for Transformers, vLLM, and AsyncVLLM

  3. Environments: ChatEnv, task-specific environments, and transforms

  4. Collectors: LLMCollector and RayLLMCollector for data collection

  5. Objectives: GRPOLoss, SFTLoss for training

Quick Example

from torchrl.modules.llm import vLLMWrapper, AsyncVLLM
from torchrl.envs.llm import ChatEnv
from torchrl.collectors.llm import LLMCollector
from transformers import AutoTokenizer

# Create vLLM engine
engine = AsyncVLLM.from_pretrained("Qwen/Qwen2.5-7B", num_replicas=2)
policy = vLLMWrapper(engine, input_mode="history")

# Create environment (using the matching tokenizer)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B")
env = ChatEnv(tokenizer=tokenizer)

# Create collector
collector = LLMCollector(env, policy, dialog_turns_per_batch=256)

Warning

The LLM API is still under development and may change in the future. Feedback, issues and PRs are welcome!


Policy Version Tracking

LLM collectors can also track the version of the policy that generated each batch of data, which is useful in asynchronous settings where collection may lag behind the latest weight update. This is done by adding a PolicyVersion transform to the environment; the collector then increments the version after each weight update. To enable it, pass either a stateful PolicyVersion instance or a boolean (track_policy_version=True) to the collector constructor.

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.weight_update.llm import VLLMWeightSyncScheme, get_model_metadata
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> scheme = VLLMWeightSyncScheme(master_port=29500, gpus_per_replica=1, num_replicas=1)
>>> collector = LLMCollector(env, policy=policy, weight_sync_schemes={"policy": scheme}, track_policy_version=True)
>>> # Get the sender and register model
>>> sender = collector._weight_senders["policy"]
>>> sender.register_model(training_model)
>>> # Initialize the collective group
>>> metadata = get_model_metadata(training_model)
>>> sender.init_all_workers_group(metadata, vllm_engine=policy.model)
>>> # Update weights
>>> sender.update_weights()
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

VLLMWeightSyncScheme([master_address, ...])

Weight synchronization scheme for vLLM engines.

VLLMWeightSender(scheme)

Sends weights to vLLM workers using collective communication.

VLLMWeightReceiver(scheme, vllm_engine)

Receives weights in a vLLM worker using collective communication.

VLLMCollectiveTransport(master_address, ...)

Transport for vLLM using collective communication (NCCL).

VLLMDoubleBufferSyncScheme(remote_addr[, ...])

Weight synchronization scheme for vLLM using double-buffered storage.

VLLMDoubleBufferWeightSender(scheme)

Sends weights to vLLM workers using double-buffered storage.

VLLMDoubleBufferWeightReceiver(scheme, ...)

Receives weights in a vLLM worker using double-buffered storage.

VLLMDoubleBufferTransport(remote_addr[, ...])

Transport for vLLM using double-buffered memory-mapped storage.

get_model_metadata(model)

Extract model metadata from a model.

Legacy Weight Updaters (Deprecated)

Deprecated since version 0.11: The vLLMUpdater and vLLMUpdaterV2 classes are deprecated in favor of the new weight synchronization schemes (VLLMWeightSyncScheme and VLLMDoubleBufferSyncScheme). These schemes provide better performance, more flexibility, and cleaner integration with collectors. The legacy updaters will be removed in a future release.

The legacy weight updaters (vLLMUpdater and vLLMUpdaterV2) are still available but are no longer recommended. Please migrate to the new weight synchronization schemes shown above.

vLLMUpdater(*args[, v2])

A class that sends weights to vLLM workers.

vLLMUpdaterV2(vllm_engine)

Simplified vLLM weight updater using the RLvLLMEngine interface.

LLMCollector(env, *[, policy, ...])

A simplified version of SyncDataCollector for LLM inference.

RayLLMCollector(env, *[, policy, ...])

A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.

Environments

The environment layer orchestrates data loading, tool execution, reward computation, and formatting. When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the policy and collector.

ChatEnv

ChatEnv serves as a blank canvas for LLM environments - it’s a basic tool designed to be extended with transforms that add specific functionality. The base ChatEnv provides the fundamental structure for managing conversation state using the History format, but it’s intentionally minimal to allow maximum flexibility.

Core Functionality

ChatEnv operates in three main modes:

  • History mode: Uses History objects for conversation management

  • Text mode: Uses simple text strings for input/output

  • Tokens mode: Uses tokenized data for input/output

The environment maintains conversation state by:

  • Reset: Initializes a new conversation with an optional system prompt

  • Step: Takes the LLM’s response and updates the conversation history, preparing the next prompt
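
A minimal sketch of history-mode usage (mirroring the fuller PythonInterpreter example later on this page; the model name is only an example):

from tensordict import TensorDict, set_list_to_stack
from transformers import AutoTokenizer
from torchrl.envs.llm import ChatEnv

set_list_to_stack(True).set()
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
env = ChatEnv(
    tokenizer=tokenizer,
    system_prompt="You are a helpful assistant.",
    batch_size=[1],
)
# Reset starts a new conversation: the system prompt followed by the user query below
reset = env.reset(TensorDict(text="What is the capital of France?", batch_size=[1]))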

Transform-Based Architecture

Transforms are the main way to extend ChatEnv with specific capabilities such as tool execution, data loading, and reward parsing.
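
For example, a tool-execution capability can be layered onto a bare ChatEnv by appending the PythonInterpreter transform documented below (a minimal sketch):

from transformers import AutoTokenizer
from torchrl.envs.llm import ChatEnv
from torchrl.envs.llm.transforms import PythonInterpreter

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
base_env = ChatEnv(tokenizer=tokenizer, batch_size=[1])
# Each append_transform call layers one capability on top of the base conversation env
env = base_env.append_transform(PythonInterpreter())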

Integration with LLM Wrappers

ChatEnv is designed to work seamlessly with both TransformersWrapper and vLLMWrapper. The environment handles the conversation state management while the wrapper handles the actual LLM inference, creating a clean separation of concerns.

On each call to step, the environment:

  • Takes the LLM’s output, specifically the full field, which contains the entire conversation so far, including the new response (e.g., history.full, text.full, tokens.full).

  • Sets this full field as the new prompt for the next LLM step (e.g., td["next", "history"].prompt, td["next", "text"].prompt, td["next", "tokens"].prompt).

  • Optionally applies transforms that insert new user messages, tool calls, or other modifications into the conversation, refining the prompt before the next LLM step.

This mechanism enables seamless multi-turn interactions and supports complex workflows such as tool use and reward shaping.

Task-Specific Environments

We provide a few task-specific environments, such as GSM8KEnv for the GSM8K dataset, IFEvalEnv for the IFEval dataset, and MLGymWrapper for MLGym integration.

These environments wrap a ChatEnv and add a DataLoadingPrimer transform (plus an optional reward-parsing transform) within a TransformedEnv, as sketched below.
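
As a sketch, a dataset-backed environment can be instantiated directly; the keyword arguments below are assumptions modeled on make_gsm8k_env's signature:

from transformers import AutoTokenizer
from torchrl.envs.llm import GSM8KEnv

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
# Assumed keyword arguments; GSM8KEnv wraps a ChatEnv with a DataLoadingPrimer
# and a GSM8KRewardParser internally.
env = GSM8KEnv(tokenizer=tokenizer, repeats=4)
reset = env.reset()  # pulls the next batch of GSM8K questions from the dataloader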

ChatEnv(*args, **kwargs)

A chat-based environment for LLMs, designed as a blank canvas for conversation and RL.

DatasetChatEnv(*args, **kwargs)

Base class for chat environment with queries pulled from a dataset.

GSM8KEnv(*args, **kwargs)

GSM8K dataset environment.

make_gsm8k_env([dataset, num_envs, repeats, ...])

A builder for an LLMEnv-based GSM8K environment.

GSM8KPrepareQuestion([in_keys, out_keys])

A transform to prepare the prompt when using GSM8k within an LLMEnv.

IFEvalEnv(*args, **kwargs)

A chat environment based on the IFEval dataset.

IfEvalScorer(*[, instruction_ids_key, ...])

Scorer for the IF-Eval task.

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

A text generation environment for language models.

LLMHashingEnv(*args, **kwargs)

A text generation environment that uses a hashing module to identify unique observations.

make_mlgym(*[, task, tasks, tokenizer, ...])

Wraps an MLGymEnv in a TorchRL Environment.

MLGymWrapper(*args, **kwargs)

A thin wrapper for MLGym environments.

GSM8KRewardParser(tokenizer[, in_keys, ...])

Reward parser for GSM8KEnv or make_gsm8k_env.

Transforms

Transforms are used to modify the data before it is passed to the LLM. Tools are usually implemented as transforms, and appended to a base environment such as ChatEnv.

An example of a tool transform is the PythonInterpreter transform, which executes Python code in the context of the LLM conversation. The PythonInterpreter can optionally use a shared PythonExecutorService for efficient resource usage across multiple environments; see the services reference for more details on the service registry system.

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(step_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

Similarly, environments that load data from a dataset are just special instances of ChatEnv augmented with a DataLoadingPrimer transform (and some dedicated reward-parsing transforms).

Designing Reward Transforms

When designing reward transforms for LLM environments, several key considerations must be addressed to ensure proper integration with the training pipeline. The examples of GSM8KRewardParser and IfEvalScorer provide excellent templates for reward transform design.

Reward Shape Requirements

The reward tensor must have the same number of dimensions as the logits, which is typically two more dimensions than the environment batch size:

  • Sparse rewards: Shape (*bsz, 1, 1) - single reward per sequence

  • Dense rewards: Shape (*bsz, num_tokens, 1) - per-token rewards

This shape requirement ensures compatibility with the loss computation pipeline. For example, in the GSM8K reward parser:

# Rewards need to have shape broadcastable to [batch x tokens x 1]
tds = tds.apply(lambda t: t.unsqueeze(-1).unsqueeze(-1))
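
Concretely, a per-sequence (sparse) reward can be expanded to the expected shape as follows (a standalone sketch with an arbitrary batch size):

import torch

bsz = (4,)                                   # environment batch size
raw_reward = torch.ones(bsz)                 # one scalar reward per sequence
sparse_reward = raw_reward.unsqueeze(-1).unsqueeze(-1)
assert sparse_reward.shape == (*bsz, 1, 1)   # broadcastable to [batch x tokens x 1]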

Done State Management

It is crucial to properly manage the done state to prevent endless generation. Common strategies include:

  1. Completion-based termination: Set done when the response is complete (e.g., History.complete=True)

  2. Content-based termination: Set done when specific content is detected (e.g., <answer> blocks)

  3. Step-based termination: Use StepCounter for predetermined step limits

Example from IFEvalScorer:

if self.set_done_if_answer and bool(answer_blocks):
    next_tensordict.set("done", torch.ones(...))
    next_tensordict.set("terminated", torch.ones(...))

Input Mode Handling

Reward transforms must handle different input modes correctly:

  • History mode: Extract text from ("history", "full") or ("history", "response")

  • Text mode: Use text directly from ("text", "full") or ("text", "response")

  • Tokens mode: Decode tokens from ("tokens", "full") or ("tokens", "response")

The GSM8K reward parser demonstrates this pattern:

if input_mode == "history":
    responses = lazy_stack([r[..., -1] for r in responses.unbind(0)])
    if hasattr(responses, "content"):
        text_completion = responses.content
elif input_mode == "text":
    text_completion = responses
elif input_mode == "tokens":
    text_completion = self.tokenizer.decode(responses.flatten(0, 1).tolist())

Specification Management

Accurate specification of reward and observation specs is essential for proper environment initialization. Both GSM8K and IFEval provide good examples:

def transform_reward_spec(self, reward_spec: Composite) -> Composite:
    shape = reward_spec.shape + (1, 1)
    reward_spec.update(
        Composite(
            reward_answer=Unbounded(shape),
            reward_think=Unbounded(shape),
            reward_right=Unbounded(shape),
            reward_contained=Unbounded(shape),
            reward=Unbounded(shape),
            success=Unbounded(shape, dtype=torch.bool),
        )
    )
    return reward_spec

Batch Processing Considerations

For efficient processing, handle batched data appropriately:

  1. Flatten batch dimensions: Use tensordict.view(-1) for processing

  2. Reshape results: Restore original batch structure after processing

  3. Handle variable-length sequences: Use proper padding and masking
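
A minimal sketch of the flatten-process-reshape pattern on a tensordict (key names are illustrative only):

import torch
from tensordict import TensorDict

td = TensorDict({"answer_score": torch.rand(2, 3)}, batch_size=[2, 3])
flat = td.view(-1)                                         # flatten batch dims for processing
reward = flat["answer_score"].unsqueeze(-1).unsqueeze(-1)  # one (1, 1) reward per flat element
td["reward"] = reward.view(2, 3, 1, 1)                     # restore the original batch structure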

Reward Aggregation Strategies

Consider different reward aggregation approaches:

  1. Simple aggregation: Sum or average multiple reward components

  2. Weighted aggregation: Apply different weights to different components

  3. Conditional rewards: Base rewards on specific conditions or thresholds

The IFEvalScorer demonstrates a sophisticated aggregation strategy:

def default_reward_aggregator(self, score: IFEvalScoreData, ...):
    # Format score (max 1.0)
    format_score = (format_components * weights).sum(dim=-1, keepdim=True)

    # Structure score (max 1.0)
    structure_score = think_score + answer_score

    # Completion bonus (max 0.2)
    completion_bonus = float(complete) * 0.2

    return format_score + structure_score + completion_bonus

Post-Processing in Replay Buffers

Rewards can also be computed after the fact by appending transforms to the replay buffer. However, done state capture must remain in the environment transform since it needs to occur on-the-fly during data collection.
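
For instance, the TopKRewardSelector listed under Objectives below can be appended to a replay buffer so that only the best completions per prompt are kept (a sketch; the import path and sizes are assumptions):

from torchrl.data import LazyStackStorage, ReplayBuffer
from torchrl.data.llm import TopKRewardSelector  # import path assumed

rb = ReplayBuffer(storage=LazyStackStorage(256))
# Keep the 2 highest-reward completions out of every 8 dialog turns per prompt
rb.append_transform(TopKRewardSelector(total_dialog_turns=8, topk_size=2))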

Error Handling and Robustness

Implement robust error handling for parsing failures:

try:
    cot, potential_answer = self.extract_tags(compl)
except ET.ParseError:
    cot, potential_answer = ("", "")

Performance Considerations

  1. Avoid redundant computations: Cache parsed results when possible

  2. Use efficient text processing: Leverage regex or XML parsing as appropriate

  3. Minimize memory allocations: Reuse tensors and avoid unnecessary copies

By following these design principles, reward transforms can be effectively integrated into the LLM training pipeline while maintaining performance and reliability.

AddThinkingPrompt(cond[, prompt, ...])

A transform that adds thinking prompts to encourage the LLM to reconsider its response.

BrowserTransform([allowed_domains, ...])

A transform that enables web browsing capabilities.

DataLoadingPrimer(*args[, use_ray_service])

A primer that loads data from a dataloader and converts it into a tensordict using stack_method.

KLComputation([gen_log_probs_full_key, ...])

A transform to compute KL divergence between two log-prob tensors and optionally add it to the reward.

KLRewardTransform(*args[, use_ray_service])

A legacy transform for computing KL divergence-based rewards.

MCPToolTransform(servers[, ...])

A transform that executes tools via the Model Context Protocol (MCP).

PolicyVersion(version_type, ...)

A transform that keeps track of the version of the policy.

PythonExecutorService([pool_size, timeout])

Ray actor that manages a pool of persistent Python interpreters.

PythonInterpreter([tokenizer, tool_name, ...])

A transform that executes Python code in the LLM response.

RayDataLoadingPrimer(*[, dataloader, ...])

A DataLoadingPrimer that creates a single actor that can be shared by multiple environments.

RetrieveKL(*args[, use_ray_service])

A transform to retrieve the KL divergence between two models' log-probabilities.

RetrieveLogProb(model, *[, ...])

A transform to retrieve log-probabilities from a model for KL divergence computation.

TemplateTransform(tokenizer[, chat_template])

A transform that applies a chat template to an input string during the forward pass, and parses strings back from the template during the backward pass.

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

Applies a tokenization operation on the specified inputs.

as_nested_tensor(list_of_tensordicts)

Stacks a list of tensordicts into a single tensordict with nested tensors.

as_padded_tensor(list_of_tensordicts[, dim, ...])

Stacks a list of tensordicts into a single tensordict with padded tensors.

Objectives

LLM post-training requires specialized loss functions that are adapted to the unique characteristics of language models.
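
A minimal training-step sketch, continuing from the Quick Example above (the constructor keyword, import path, and optimizer setup are assumptions; consult the GRPOLoss reference for the exact signature):

import torch
from torchrl.objectives.llm import GRPOLoss  # import path assumed

# `policy` and `collector` come from the Quick Example above;
# `training_model` is the Hugging Face module being optimized.
loss_fn = GRPOLoss(actor_network=policy)  # assumed keyword name
optim = torch.optim.AdamW(training_model.parameters(), lr=1e-5)

for data in collector:
    loss_vals = loss_fn(data)            # a GRPOLossOutput tensorclass
    loss_vals.loss_objective.backward()  # add further loss terms (e.g. KL) as needed
    optim.step()
    optim.zero_grad()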

LLMLossOutput(loss_objective, clip_fraction, ...)

GRPOLoss(*args, **kwargs)

GRPO loss.

GRPOLossOutput(loss_objective, ...[, ...])

CISPOLossOutput(loss_objective, ...[, ...])

DAPO(*args, **kwargs)

DAPO (Clip-Higher over GRPO).

DAPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

Monte-Carlo advantage computation engine.

SFTLoss(*args, **kwargs)

Supervised fine-tuning loss.

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

A replay-buffer transform that selects the top-k rewards for each prompt.

