torchrl.trainers package

The trainer package provides utilities for writing reusable training scripts. The core idea is a trainer that implements a nested loop: the outer loop runs the data collection steps, and the inner loop runs the optimization steps on each collected batch.
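The nested loop can be pictured with the plain-Python sketch below. It is a toy model of the control flow only, not the torchrl implementation: `train`, the iterable used as a collector, and the `optim_step` callable are hypothetical stand-ins for a real data collector, loss module and optimizer.

```python
from typing import Callable, Iterable, List

Batch = List[float]


def train(
    collector: Iterable[Batch],
    optim_step: Callable[[Batch], float],
    optim_steps_per_batch: int,
    total_frames: int,
) -> List[float]:
    """Toy version of the collect-then-optimize nested loop (hypothetical)."""
    losses: List[float] = []
    frames = 0
    # Outer loop: one iteration per batch of collected data.
    for batch in collector:
        frames += len(batch)
        # Inner loop: several optimization steps on the same batch.
        for _ in range(optim_steps_per_batch):
            losses.append(optim_step(batch))
        # Stop once the frame budget is spent.
        if frames >= total_frames:
            break
    return losses


# Fake collector yielding batches of 4 "frames"; the fake step returns the batch mean.
batches = ([float(i)] * 4 for i in range(100))
losses = train(batches, lambda b: sum(b) / len(b), optim_steps_per_batch=2, total_frames=12)
```

With a 12-frame budget and 4 frames per batch, the outer loop runs three times and the inner loop twice per batch, giving six optimization steps in total.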

Key Features

  • Modular hook system: Customize training at 18 different points in the loop

  • Checkpointing support: Save and restore training state with torch, torchsnapshot, or memmap (set via the CKPT_BACKEND environment variable)

  • Algorithm trainers: High-level trainers for PPO, SAC, DQN, DDPG, IQL, CQL with Hydra configuration

  • Builder helpers: Utilities for constructing collectors, losses, and replay buffers
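The hook system follows a "hooks register themselves" pattern: each hook object attaches a callable to a named point in the loop, and the trainer calls every callable registered at that point. The sketch below models this in plain Python. `ToyTrainer` and `RecordBatch` are hypothetical; the point names `"batch_process"` and `"post_steps"` are borrowed loosely from torchrl for illustration, and only two of the available points are modeled.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List


class ToyTrainer:
    """Hypothetical trainer that runs registered callables at named hook points."""

    def __init__(self) -> None:
        self.hooks: DefaultDict[str, List[Callable]] = defaultdict(list)

    def register_op(self, point: str, op: Callable) -> None:
        # Hooks attached to the same point run in registration order.
        self.hooks[point].append(op)

    def _run(self, point: str, *args) -> None:
        for op in self.hooks[point]:
            op(*args)

    def train(self, batches: List[int]) -> None:
        for batch in batches:
            self._run("batch_process", batch)
            # ... optimization steps would happen here ...
            self._run("post_steps", batch)


class RecordBatch:
    """Hypothetical hook: records every batch it sees at the 'post_steps' point."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def __call__(self, batch: int) -> None:
        self.seen.append(batch)

    def register(self, trainer: ToyTrainer) -> None:
        # The hook, not the trainer, decides which point it attaches to.
        trainer.register_op("post_steps", self)


trainer = ToyTrainer()
recorder = RecordBatch()
recorder.register(trainer)
trainer.train([1, 2, 3])
```

Letting each hook choose its own attachment point keeps the trainer loop generic: adding a feature means registering a new hook rather than editing the loop.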

Quick Example

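from torchrl.trainers import Trainer
from torchrl.trainers import UpdateWeights, LogScalar

# `collector`, `loss` and `optimizer` are assumed to be built beforehand:
# a data collector, a loss module, and a torch.optim optimizer over
# the loss module's parameters.

# Create trainer
trainer = Trainer(
    collector=collector,
    total_frames=1_000_000,
    frame_skip=1,  # frame skip used by the environment (1 = no skipping)
    optim_steps_per_batch=10,  # inner-loop optimization steps per collected batch
    loss_module=loss,
    optimizer=optimizer,
)

# Register hooks: sync the collector's policy weights every 10 collected batches
UpdateWeights(collector, 10).register(trainer)
# Log the "reward" scalar from the collected data
LogScalar("reward").register(trainer)

# Train: outer loop over collected batches, inner loop over optimization steps
trainer.train()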
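In the example above, the second argument to `UpdateWeights` is the update interval, counted in collected batches. The gating logic behind such interval-based hooks can be pictured with the plain-Python sketch below. `EveryNBatches` is hypothetical and is not how torchrl implements `UpdateWeights`; recording the batch count stands in for the real action (for example, syncing policy weights).

```python
from typing import List


class EveryNBatches:
    """Hypothetical hook that fires its action once every `interval` batches."""

    def __init__(self, interval: int) -> None:
        self.interval = interval
        self.counter = 0
        self.fired_at: List[int] = []

    def __call__(self) -> None:
        # Called once per collected batch.
        self.counter += 1
        if self.counter % self.interval == 0:
            # Stand-in for the real action, e.g. pushing new weights to a collector.
            self.fired_at.append(self.counter)


hook = EveryNBatches(10)
for _ in range(35):  # simulate 35 collected batches
    hook()
```

After 35 batches, the action has fired three times, at batches 10, 20 and 30. A larger interval trades policy freshness in the collector for less synchronization overhead.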

Documentation Sections