LLM interface

TorchRL offers a set of tools for LLM post-training, as well as example training and setup scripts.

Collectors

TorchRL offers specialized collector classes (LLMCollector and RayLLMCollector) that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.

LLM collectors can track the version of the policy, which is useful for some use cases. This is done by adding a PolicyVersion transform to the environment, which the collector then increments after each weight update. To enable it, pass either the stateful transform instance or a boolean to the collector constructor.

>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
...     print(data["policy_version"])

vLLMUpdater([master_address, master_port, ...])

A class that sends weights to vLLM workers.

LLMCollector(env, *[, policy, ...])

A simplified version of SyncDataCollector for LLM inference.

RayLLMCollector(env, *[, policy, ...])

A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.

Data structures

To handle text-based data such as conversations, we offer a few data structures dedicated to carrying data for LLM post-training.

History(role, content[, is_complete, ...])

ContentBase(type, text, url, data, ...[, ...])

LLMData([tokens, tokens_response, ...])
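
As a hedged sketch (the exact import path and helper names should be checked against the class references above), a short conversation can be represented as a stack of History entries and rendered with a chat template:

>>> import torch
>>> from torchrl.data.llm import History
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> # One History entry per message; a conversation is a stack of entries
>>> chat = torch.stack([
...     History(role="system", content="You are a helpful assistant."),
...     History(role="user", content="What is the square of 2?"),
... ])
>>> # Render the conversation with the tokenizer's chat template
>>> print(chat.apply_chat_template(tokenizer=tokenizer))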

Environments

When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.

Therefore, the fundamental structure of an LLM post-training pipeline is:

  • A policy that wraps the LLM and the LLM only

  • An environment that handles the world around the LLM (loading prompts from a dataset, interacting with tools, computing rewards when needed, and formatting data);

  • A data collector that takes the policy (the LLM) and the environment, and handles the inference part of the pipeline:
    • Running reset and step, and gathering actions;

    • Yielding the data in a consistent format - or populating a buffer;

    • Updating the policy weights (through WeightUpdaterBase classes)

  • A replay buffer that stores the data collected using the collector

  • A loss that takes the LLM’s output and returns a loss (through GRPOLoss for example)

These elements are presented in the GRPO scripts in the sota-implementations/llm directory.
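
Schematically, these pieces fit together roughly as follows. This is a hedged sketch, not the exact API: make_env, make_policy and make_train_policy are placeholders, and the keyword arguments are indicative only (see the GRPO scripts for working configurations).

>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.data import LazyStackStorage, ReplayBuffer
>>> from torchrl.objectives.llm import GRPOLoss
>>> env = make_env()        # place your code here: e.g., a ChatEnv with reward transforms
>>> policy = make_policy()  # place your code here: e.g., a vLLMWrapper around the model
>>> collector = LLMCollector(env, policy=policy)
>>> buffer = ReplayBuffer(storage=LazyStackStorage(1000))
>>> loss_fn = GRPOLoss(actor_network=make_train_policy())  # training-side wrapper of the same LLM
>>> for data in collector:
...     buffer.extend(data)
...     batch = buffer.sample(16)
...     loss = loss_fn(batch)
...     # backward pass and optimizer step go here, then propagate the new weights:
...     # collector.update_policy_weights_(...)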

The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can easily extend or modify existing environments using transforms. This approach enables the isolation of individual components within specific EnvBase or Transform subclasses, making it simpler to augment or alter the environment logic.

Available Environment Classes and Utilities

TorchRL provides various environment classes and utilities for working with LLMs, including:

  • Various environment classes (ChatEnv, DatasetChatEnv, GSM8KEnv, etc.)

  • Utility functions (make_gsm8k_env, make_mlgym, etc.)

  • Transforms and other supporting classes (KLRewardTransform, TemplateTransform, Tokenizer, etc.)

These components can be used to create customized environments tailored to specific use cases and requirements.
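
For instance, a dataset-backed environment such as GSM8KEnv can be instantiated directly. This is a hedged sketch: the keyword arguments and output keys below are indicative and should be checked against the class signatures listed next.

>>> from torchrl.envs.llm import GSM8KEnv
>>> from transformers import AutoTokenizer
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> env = GSM8KEnv(tokenizer=tokenizer, num_envs=2)  # pulls questions from the GSM8K dataset
>>> reset = env.reset()
>>> print(reset["history"])  # the question is stored in the conversation history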

ChatEnv(*args, **kwargs)

A chat-based environment.

DatasetChatEnv(*args, **kwargs)

Base class for chat environment with queries pulled from a dataset.

GSM8KEnv(*args, **kwargs)

GSM8K dataset environment.

make_gsm8k_env([dataset, num_envs, repeats, ...])

A builder for an LLMEnv-based GSM8K environment.

GSM8KPrepareQuestion([in_keys, out_keys])

A transform to prepare the prompt when using GSM8k within an LLMEnv.

IFEvalEnv(*args, **kwargs)

A chat environment based on the IFEval dataset.

IfEvalScorer(*[, instruction_ids_key, ...])

Scorer for the IF-Eval task.

IFEvalScoreData(prompt_level_strict_acc, ...)

LLMEnv(*args, **kwargs)

A text generation environment for language models.

LLMHashingEnv(*args, **kwargs)

A text generation environment that uses a hashing module to identify unique observations.

make_mlgym(*[, task, tasks, tokenizer, ...])

Wraps an MLGymEnv in a TorchRL Environment.

MLGymWrapper(*args, **kwargs)

A thin wrapper for MLGym environments.

GSM8KRewardParser(tokenizer[, in_keys, ...])

Reward parser for GSM8KEnv or make_gsm8k_env.

Transforms

Transforms are used to modify the data before it is passed to the LLM. Tools are usually implemented as transforms, and appended to a base environment such as ChatEnv.

An example of a tool transform is the PythonInterpreter transform, which is used to execute Python code in the context of the LLM.

>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
...     tokenizer=tokenizer,
...     system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
...     user_role="user",
...     system_role="system",
...     batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
...     text="Let's write a Python function that returns the square of a number.",
...     batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
...     return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(step_data)
>>> # The history is a stack of chat messages.
>>> #  The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
 'You are an assistant that can execute Python code. Decorate your code with '
 '```python``` tags.<|im_end|>\n'
 '<|im_start|>user\n'
 "Let's write a Python function that returns the square of a "
 'number.<|im_end|>\n'
 '<|im_start|>assistant\n'
 'Here is a block of code to be executed in python:\n'
 '```python\n'
 'def square(x):\n'
 '    return x * x\n'
 "print('testing the square function with input 2:', square(2))\n"
 '```<|im_end|>\n'
 '<|im_start|>user\n'
 '<tool_response>\n'
 'Code block 1 executed successfully:\n'
 'testing the square function with input 2: 4\n'
 '\n'
 '</tool_response><|im_end|>\n'
 '<|im_start|>assistant\n']

Similarly, environments that load data from a dataset are just special instances of ChatEnv augmented with a DataLoadingPrimer transform (and some dedicated reward-parsing transforms).
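
A hedged sketch of that pattern (the dataloader is a placeholder, and the import path and keyword arguments should be checked against the references below):

>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.envs.llm.transforms import DataLoadingPrimer
>>> dataloader = make_dataloader()  # place your code here: any iterable yielding prompt entries
>>> env = ChatEnv(tokenizer=tokenizer, batch_size=[1])
>>> env = env.append_transform(DataLoadingPrimer(dataloader))
>>> reset = env.reset()  # the primer injects the next prompt from the dataloader at reset time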

DataLoadingPrimer(dataloader, *[, primers, ...])

A primer that loads data from a dataloader and converts it into a tensordict using stack_method.

KLRewardTransform(actor[, coef, in_keys, ...])

A transform to add a KL[pi_current||pi_0] correction term to the reward.

RetrieveLogProb(actor, *[, history_key, ...])

A transform to retrieve the log-probs of a text given a reference model.

MCPToolTransform(tools, tool_schemas[, ...])

A transform that executes MCP-style tools in response to LLM actions.

BrowserTransform([allowed_domains, ...])

A transform that enables web browsing capabilities.

PythonInterpreter([tokenizer, tool_name, ...])

A transform that executes Python code in the LLM response.

PolicyVersion([version_type])

A transform that keeps track of the version of the policy.

TemplateTransform(tokenizer[, chat_template])

A transform that applies a chat template to an input string during the forward pass, and parses strings back from the template during the backward pass.

Tokenizer([in_keys, out_keys, in_keys_inv, ...])

Applies a tokenization operation on the specified inputs.

as_nested_tensor(list_of_tensordicts)

Stacks a list of tensordicts into a single tensordict with nested tensors.

as_padded_tensor(list_of_tensordicts[, dim, ...])

Stacks a list of tensordicts into a single tensordict with padded tensors.

Modules

The torchrl.modules.llm section provides a set of wrappers and utility functions for popular training and inference backends. The main goals of these primitives are to:

  • Unify the input / output data format across training and inference pipelines;

  • Unify the input / output data format across backends (to be able to use different backends across losses and collectors, for instance)

  • Give appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.)

Wrappers

TransformersWrapper(*args, **kwargs)

A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log probability computation.

vLLMWrapper(*args, **kwargs)

A wrapper class for vLLM models, providing a consistent interface for text generation and log probability computation, similar to the Hugging Face Transformers interface.
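
A hedged usage sketch of the Transformers wrapper (the generate flag and the structure of the input tensordict are indicative; check the wrapper signatures for the exact arguments):

>>> from torchrl.modules.llm import TransformersWrapper
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> # generate=True configures the wrapper for text generation rather than log-prob computation
>>> policy = TransformersWrapper(model, tokenizer=tokenizer, generate=True)
>>> out = policy(data)  # data is a TensorDict carrying the prompt / conversation history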

Utils

CategoricalSequential(*args, **kwargs)

A ProbabilisticTensorDictSequential subclass meant to work with LLMs.

LLMOnDevice(*args[, bundle_indices])

A thin wrapper around vllm.LLM to control its device placement.

make_vllm_worker(*, model_name[, devices, ...])

Creates a vLLM inference engine with tensor parallelism support.

stateless_init_process_group(master_address, ...)

Initializes a stateless process group for distributed communication.

vLLMWorker(*args, **kwargs)

vLLM worker for Ray.

Objectives

LLM post-training requires appropriately adapted versions of the losses implemented in TorchRL.

GRPO

The GRPOLoss class is a thin wrapper around the PPOLoss class that adds the LLM-specific functionality.
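
A hedged construction sketch (argument names are indicative; train_policy and sampled_batch are placeholders):

>>> from torchrl.objectives.llm import GRPOLoss
>>> # the training-side policy wrapper (e.g., a TransformersWrapper) acts as the actor network
>>> loss_fn = GRPOLoss(actor_network=train_policy)
>>> loss_vals = loss_fn(sampled_batch)  # a GRPOLossOutput with loss_objective and related diagnostics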

GRPOLoss(*args, **kwargs)

GRPO loss.

GRPOLossOutput(loss_objective, ...[, ...])

MCAdvantage(grpo_size[, prompt_key, ...])

Monte-Carlo advantage computation engine.

SFT

SFTLoss(*args, **kwargs)

Supervised fine-tuning loss.

SFTLossOutput(loss_sft[, loss_kl_to_ref, ...])

TopKRewardSelector(total_dialog_turns, topk_size)

A replay-buffer transform that selects the top-k rewards for each prompt.
