LLM interface
TorchRL offers a set of tools for LLM post-training, as well as some examples for training or setup.
Collectors
TorchRL offers specialized collector classes (LLMCollector and RayLLMCollector) that are tailored for LLM use cases. We also provide dedicated updaters for some inference engines.
LLM collectors allow you to track the version of the policy, which is useful for some use cases. This is done by adding a PolicyVersion transform to the environment, whose version counter is then incremented by the collector after each weight update. To enable this, either provide the stateful version of the transform, or pass a boolean to the collector constructor.
>>> from torchrl.envs.llm.transforms import PolicyVersion
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.collectors.llm.weight_update import vLLMUpdater
>>> env = make_env() # place your code here
>>> policy = make_policy() # place your code here
>>> collector = LLMCollector(env, policy=policy, weight_updater=vLLMUpdater(), track_policy_version=True)
>>> # init the updater
>>> collector.weight_updater.init(...)
>>> # the version is incremented after each weight update
>>> collector.update_policy_weights_(state_dict=...)
>>> print(collector.policy_version_tracker.version)
>>> # the policy version is written in the data
>>> for data in collector:
... print(data["policy_version"])
- vLLMUpdater: A class that sends weights to vLLM workers.
- LLMCollector: A simplified version of SyncDataCollector for LLM inference.
- RayLLMCollector: A lightweight Ray implementation of the LLM Collector that can be extended and sampled remotely.
Data structures
To handle text-based data (such as conversations), we offer a few dedicated data structures for carrying data for LLM post-training.
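For instance, conversations can be carried around as a stack of chat messages. A minimal sketch, assuming the History class from torchrl.data.llm (the exact import path and constructor may differ):
>>> from torchrl.data.llm import History  # assumed import path
>>> # build a batch containing a single two-message conversation
>>> history = History.from_chats([[
...     {"role": "system", "content": "You are a helpful assistant."},
...     {"role": "user", "content": "What is the capital of France?"},
... ]])
>>> print(history.role)     # stacked roles of the conversation
>>> print(history.content)  # stacked message contents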
Environments
When fine-tuning an LLM using TorchRL, the environment is a crucial component of the inference pipeline, alongside the policy and collector. Environments manage operations that are not handled by the LLM itself, such as interacting with tools, loading prompts from datasets, computing rewards (when necessary), and formatting data.
Therefore, the fundamental structure of an LLM post-training pipeline is:
- A policy that wraps the LLM, and the LLM only
- An environment that handles the world around the LLM:
  - Loading data (through DataLoadingPrimer)
  - Formatting data (through TemplateTransform)
  - Executing tools (through PythonInterpreter or MCPToolTransform)
  - Computing rewards online, if needed (through KLRewardTransform)
- A data collector that takes the policy (the LLM) and the environment, and handles the inference part of the pipeline:
  - Running reset, step and gathering actions;
  - Yielding the data in a consistent format - or populating a buffer;
  - Updating the policy weights (through WeightUpdaterBase classes)
- A replay buffer that stores the data collected using the collector
- A loss that takes the LLM's output and returns a loss (through GRPOLoss, for example)
These elements are presented in the GRPO scripts in the sota-implementations/llm directory.
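Schematically, these pieces can be wired together as follows. This is a simplified sketch: make_env and make_policy are placeholders, and the import paths, storage size and sample size are indicative; the actual GRPO scripts handle many more details (async collection, weight synchronization, etc.).
>>> from torchrl.collectors.llm import LLMCollector
>>> from torchrl.data import ReplayBuffer, LazyStackStorage
>>> from torchrl.objectives.llm import GRPOLoss
>>> env = make_env()        # ChatEnv + transforms (data loading, tools, rewards); placeholder
>>> policy = make_policy()  # a wrapped LLM (e.g., a Transformers or vLLM wrapper); placeholder
>>> collector = LLMCollector(env, policy=policy)
>>> buffer = ReplayBuffer(storage=LazyStackStorage(128))
>>> loss_fn = GRPOLoss(actor_network=policy)
>>> for data in collector:
...     buffer.extend(data)
...     sample = buffer.sample(16)
...     loss_vals = loss_fn(sample)
...     loss_vals.sum(reduce=True).backward()
...     # optimizer step here, then sync the new weights to the inference engine
...     # (e.g., through the collector's weight updater, such as vLLMUpdater)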
The design of environments in TorchRL allows for flexibility and modularity. By framing tasks as environments, users can easily extend or modify existing environments using transforms. This approach enables the isolation of individual components within specific EnvBase or Transform subclasses, making it simpler to augment or alter the environment logic.
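For instance, extra capabilities can be layered onto an existing environment simply by appending transforms to it. A short sketch reusing classes documented on this page (the ChatEnv arguments are indicative):
>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.envs.llm.transforms import PolicyVersion, PythonInterpreter
>>> env = ChatEnv(tokenizer=tokenizer, batch_size=[1])  # tokenizer: defined elsewhere
>>> env = env.append_transform(PythonInterpreter())     # add Python tool execution
>>> env = env.append_transform(PolicyVersion())         # track the policy version in the data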
Available Environment Classes and Utilities
TorchRL provides various environment classes and utilities for working with LLMs, including:
- Various environment classes (ChatEnv, DatasetChatEnv, GSM8KEnv, etc.)
- Utility functions (make_gsm8k_env, make_mlgym, etc.)
- Transforms and other supporting classes (KLRewardTransform, TemplateTransform, Tokenizer, etc.)
These components can be used to create customized environments tailored to specific use cases and requirements.
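For example, a ready-made dataset environment such as GSM8KEnv can presumably be built from a tokenizer alone (a sketch; additional constructor options, such as batch size or reward settings, may apply):
>>> from transformers import AutoTokenizer
>>> from torchrl.envs.llm import GSM8KEnv
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> env = GSM8KEnv(tokenizer=tokenizer)  # pulls prompts from the GSM8K dataset at reset time
>>> reset = env.reset()
>>> print(reset["history"])  # the prompt formatted as a chat history (key name assumed)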
- ChatEnv: A chat-based environment.
- DatasetChatEnv: Base class for chat environments with queries pulled from a dataset.
- GSM8KEnv: GSM8K dataset environment.
- make_gsm8k_env: A builder for an LLMEnv-based GSM8K environment.
- A transform to prepare the prompt when using GSM8K within an LLMEnv.
- IFEvalEnv: A chat environment based on the IFEval dataset.
- IFEvalScorer: Scorer for the IF-Eval task.
- LLMEnv: A text generation environment for language models.
- LLMHashingEnv: A text generation environment that uses a hashing module to identify unique observations.
- MLGymWrapper: Wraps an MLGymEnv in a TorchRL environment.
- make_mlgym: A thin wrapper for MLGym environments.
- GSM8KRewardParser: Reward parser for GSM8KEnv or make_gsm8k_env.
Transforms
Transforms are used to modify the data before it is passed to the LLM. Tools are usually implemented as transforms and appended to a base environment such as ChatEnv.
An example of a tool transform is the PythonInterpreter transform, which is used to execute Python code in the context of the LLM.
>>> from torchrl.envs.llm.transforms import PythonInterpreter
>>> from torchrl.envs.llm import ChatEnv
>>> from tensordict import TensorDict, set_list_to_stack
>>> from transformers import AutoTokenizer
>>> from pprint import pprint
>>> set_list_to_stack(True).set()
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> base_env = ChatEnv(
... tokenizer=tokenizer,
... system_prompt="You are an assistant that can execute Python code. Decorate your code with ```python``` tags.",
... user_role="user",
... system_role="system",
... batch_size=[1],
... )
>>> env = base_env.append_transform(PythonInterpreter())
>>> env.set_seed(0)
>>> # Pass the reset data - the prompt - to the environment
>>> reset_data = env.reset(TensorDict(
... text="Let's write a Python function that returns the square of a number.",
... batch_size=[1])
... )
>>> # Simulate an action - i.e., a response from the LLM (as if we were an LLM)
>>> action = """Here is a block of code to be executed in python:
... ```python
... def square(x):
... return x * x
... print('testing the square function with input 2:', square(2))
... ```
... <|im_end|>
... """
>>> step_data = reset_data.set("text_response", [action])
>>> s, s_ = env.step_and_maybe_reset(step_data)
>>> # The history is a stack of chat messages.
>>> # The python interpreter transform has executed the code in the last message.
>>> pprint(s_["history"].apply_chat_template(tokenizer=tokenizer))
['<|im_start|>system\n'
'You are an assistant that can execute Python code. Decorate your code with '
'```python``` tags.<|im_end|>\n'
'<|im_start|>user\n'
"Let's write a Python function that returns the square of a "
'number.<|im_end|>\n'
'<|im_start|>assistant\n'
'Here is a block of code to be executed in python:\n'
'```python\n'
'def square(x):\n'
' return x * x\n'
"print('testing the square function with input 2:', square(2))\n"
'```<|im_end|>\n'
'<|im_start|>user\n'
'<tool_response>\n'
'Code block 1 executed successfully:\n'
'testing the square function with input 2: 4\n'
'\n'
'</tool_response><|im_end|>\n'
'<|im_start|>assistant\n']
Similarly, environments that load data from a dataset are just special instances of ChatEnv augmented with a DataLoadingPrimer transform (and some dedicated reward parsing transforms).
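A rough sketch of assembling such an environment by hand (the DataLoadingPrimer import path and arguments shown here are assumptions; refer to the class documentation for the exact signature):
>>> from torch.utils.data import DataLoader
>>> from torchrl.envs.llm import ChatEnv
>>> from torchrl.envs.llm.transforms import DataLoadingPrimer
>>> dataloader = DataLoader(my_prompt_dataset, batch_size=1)  # my_prompt_dataset: any dataset of prompts (placeholder)
>>> env = ChatEnv(tokenizer=tokenizer, batch_size=[1])
>>> env = env.append_transform(DataLoadingPrimer(dataloader=dataloader))  # pulls a fresh prompt at every reset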
- DataLoadingPrimer: A primer that loads data from a dataloader and converts it into a tensordict.
- KLRewardTransform: A transform to add a KL[pi_current||pi_0] correction term to the reward.
- RetrieveLogProb: A transform to retrieve the log-probs of a text given a reference model.
- MCPToolTransform: A transform that executes MCP-style tools in response to LLM actions.
- BrowserTransform: A transform that enables web browsing capabilities.
- PythonInterpreter: A transform that executes Python code in the LLM response.
- PolicyVersion: A transform that keeps track of the version of the policy.
- TemplateTransform: A transform that applies a chat template to an input string during the forward pass, and parses the strings back from the template during the backward pass.
- Tokenizer: Applies a tokenization operation on the specified inputs.
- as_nested_tensor: Stacks a list of tensordicts into a single tensordict with nested tensors.
- as_padded_tensor: Stacks a list of tensordicts into a single tensordict with padded tensors.
Modules
The torchrl.modules.llm section provides a set of wrappers and utility functions for popular training and inference backends. The main goals of these primitives are to:
- Unify the input / output data format across training and inference pipelines;
- Unify the input / output data format across backends (to be able to use different backends across losses and collectors, for instance);
- Give appropriate tooling to construct these objects in typical RL settings (resource allocation, async execution, weight update, etc.).
Wrappers
- TransformersWrapper: A wrapper class for Hugging Face Transformers models, providing a consistent interface for text generation and log-probability computation.
- vLLMWrapper: A wrapper class for vLLM models, providing a consistent interface for text generation and log-probability computation, similar to the Hugging Face Transformers interface.
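For instance, a Hugging Face model can be wrapped so that it consumes and produces the same tensordict-based data format as the rest of the pipeline (a sketch; the keyword arguments are indicative and may differ from the actual signature):
>>> from transformers import AutoModelForCausalLM, AutoTokenizer
>>> from torchrl.modules.llm import TransformersWrapper
>>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
>>> policy = TransformersWrapper(model, tokenizer=tokenizer, generate=True)  # generate=True: sample new tokens
>>> out = policy(data)  # data: a tensordict carrying the prompt or chat history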
Utils
- A ProbabilisticTensorDictSequential subclass meant to work with LLMs.
- A thin wrapper around vllm.LLM to control its placement devices.
- Creates a vLLM inference engine with tensor parallelism support.
- Initializes a stateless process group for distributed communication.
- A vLLM worker for Ray.
Objectives
LLM post-training requires appropriate versions of the losses implemented in TorchRL.
GRPO
The GRPOLoss class is a thin wrapper around the PPOLoss class that encodes the LLM-specific functionalities.
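A minimal sketch of its intended use (the constructor arguments and import path shown here are indicative; see the class documentation for the full set of options):
>>> from torchrl.objectives.llm import GRPOLoss
>>> loss_fn = GRPOLoss(actor_network=policy)  # policy: a wrapped LLM such as TransformersWrapper
>>> loss_vals = loss_fn(sample)               # sample: a batch drawn from the replay buffer
>>> loss_vals.sum(reduce=True).backward()     # aggregate the loss terms and backpropagate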
- GRPOLoss: GRPO loss.
- MCAdvantage: Monte-Carlo advantage computation engine.
SFT
- SFTLoss: Supervised fine-tuning loss.
- TopKRewardSelector: A replay-buffer transform that selects the top-k rewards for each prompt.