torchrl.trainers package¶
The trainer package provides utilities to write re-usable training scripts. The core idea is to use a trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner loop runs the optimization steps. We believe this fits multiple RL training schemes, such as on-policy, off-policy, model-based and model-free solutions, offline RL and others. More particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.
The trainer.train() method can be sketched as follows:
>>> for batch in collector:
...     batch = self._process_batch_hook(batch)  # "batch_process"
...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
...     self._pre_optim_hook()  # "pre_optim_steps"
...     for j in range(self.optim_steps_per_batch):
...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
...         losses = self.loss_module(sub_batch)
...         self._post_loss_hook(sub_batch)  # "post_loss"
...         self.optimizer.step()
...         self.optimizer.zero_grad()
...         self._post_optim_hook()  # "post_optim"
...         self._post_optim_log(sub_batch)  # "post_optim_log"
...     self._post_steps_hook()  # "post_steps"
...     self._post_steps_log_hook(batch)  # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: "batch_process", "pre_optim_steps", "process_optim_batch", "post_loss", "post_steps", "post_optim", "pre_steps_log", "post_steps_log", "post_optim_log" and "optimizer". They are indicated in the comments where they are applied.
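To make the firing order concrete, the following is a minimal, pure-Python sketch of a nested loop that dispatches callables registered under the hook names above. ToyTrainer and its _run helper are hypothetical stand-ins written for this illustration; they are not the TorchRL implementation, and the loss and optimizer steps are omitted.

```python
from collections import defaultdict


class ToyTrainer:
    """Hypothetical stand-in that only reproduces the order in which hooks fire."""

    def __init__(self, optim_steps_per_batch):
        self.optim_steps_per_batch = optim_steps_per_batch
        self._ops = defaultdict(list)  # hook name -> list of registered callables

    def register_op(self, dest, op):
        self._ops[dest].append(op)

    def _run(self, dest, *args):
        # Call every callable registered under `dest`, in registration order.
        for op in self._ops[dest]:
            op(*args)

    def train(self, collector):
        for batch in collector:  # outer loop: data collection
            self._run("batch_process", batch)
            self._run("pre_steps_log", batch)
            self._run("pre_optim_steps")
            for _ in range(self.optim_steps_per_batch):  # inner loop: optimization
                self._run("process_optim_batch", batch)
                self._run("post_loss", batch)
                self._run("post_optim")
                self._run("post_optim_log", batch)
            self._run("post_steps")
            self._run("post_steps_log", batch)


calls = []
trainer = ToyTrainer(optim_steps_per_batch=2)
trainer.register_op("pre_optim_steps", lambda: calls.append("pre_optim_steps"))
trainer.register_op("post_optim", lambda: calls.append("post_optim"))
trainer.register_op("post_steps", lambda: calls.append("post_steps"))
trainer.train(collector=[{"reward": 1.0}])  # a single fake batch
print(calls)
```

With one batch and two optimization steps, the "post_optim" callable runs twice while the two outer-loop hooks run once each.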
Hooks can be split into 3 categories: data processing ("batch_process" and "process_optim_batch"), logging ("pre_steps_log", "post_optim_log" and "post_steps_log") and operation hooks ("pre_optim_steps", "post_loss", "post_optim" and "post_steps").
Data processing hooks update a tensordict of data. The hook's __call__ method should accept a TensorDict object as input and update it according to some strategy. Examples of such hooks include replay buffer extension (ReplayBufferTrainer.extend), data normalization (including updates of the normalization constants) and data subsampling (BatchSubSampler).
Logging hooks take a batch of data presented as a TensorDict and write to the logger some information retrieved from that data. Examples include the LogValidationReward hook and the reward logger (LogScalar). These hooks should return a dictionary (or None) containing the data to log. The key "log_pbar" is reserved for a boolean value indicating whether the logged values should be displayed on the progress bar printed in the training log.
Operation hooks execute specific operations over the models, data collectors, target network updates and such. For instance, syncing the weights of the collectors using UpdateWeights or updating the priority of the replay buffer using ReplayBufferTrainer.update_priority are examples of operation hooks. They are data-independent (they do not require a TensorDict input) and are simply meant to be executed once at every iteration (or every N iterations).
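The return convention of logging hooks can be illustrated with a small sketch of how a caller might consume the returned dictionary. The helper dispatch_log and the logged and pbar dictionaries are hypothetical names introduced for this example; the actual trainer may handle these values differently.

```python
def dispatch_log(result, logged, pbar):
    """Hypothetical consumer of a logging hook's return value."""
    if result is None:  # the hook chose not to log anything at this call
        return
    result = dict(result)  # copy so the hook's own dictionary is left untouched
    show_on_pbar = result.pop("log_pbar", False)  # reserved key, not a logged value
    for key, value in result.items():
        logged[key] = value
        if show_on_pbar:
            pbar[key] = value


logged, pbar = {}, {}
dispatch_log({"reward": 1.5, "log_pbar": True}, logged, pbar)
dispatch_log({"grad_norm": 0.2, "log_pbar": False}, logged, pbar)
dispatch_log(None, logged, pbar)
print(logged, pbar)
```

Both values reach the logger, but only the entry returned with "log_pbar" set to True is forwarded to the progress bar.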
The hooks provided by TorchRL usually inherit from a common abstract class, TrainerHookBase, and all implement three base methods: state_dict and load_state_dict methods for checkpointing, and a register method that registers the hook at its default location in the trainer. This method takes a trainer and a module name as input. For instance, the following logging hook is executed every 10 calls to "post_optim_log":
>>> from torchrl.trainers import TrainerHookBase
>>>
>>> class LoggingHook(TrainerHookBase):
...     def __init__(self):
...         self.counter = 0
...
...     def register(self, trainer, name="logging_hook"):
...         # register the hook as a module so that its state is checkpointed
...         trainer.register_module(name, self)
...         trainer.register_op("post_optim_log", self)
...
...     def state_dict(self):
...         return {"counter": self.counter}
...
...     def load_state_dict(self, state_dict):
...         self.counter = state_dict["counter"]
...
...     def __call__(self, batch):
...         # log on calls 0, 10, 20, ... and skip the others
...         if self.counter % 10 == 0:
...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
...         else:
...             out = None
...         self.counter += 1
...         return out
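The role of state_dict and load_state_dict can be checked without TorchRL installed. The sketch below uses a plain-Python stand-in (EveryN, a hypothetical class that does not inherit from TrainerHookBase) to show that restoring the counter preserves the "every 10 calls" cadence across a simulated restart.

```python
class EveryN:
    """Plain-Python stand-in for a counting hook; for illustration only."""

    def __init__(self, interval=10):
        self.interval = interval
        self.counter = 0

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        self.counter = state_dict["counter"]

    def __call__(self, batch):
        fire = self.counter % self.interval == 0
        self.counter += 1
        return {"some_value": batch["some_value"], "log_pbar": False} if fire else None


hook = EveryN(interval=10)
fired = [hook({"some_value": i}) is not None for i in range(7)]

# Simulate a restart: a fresh hook picks up the saved counter.
restored = EveryN(interval=10)
restored.load_state_dict(hook.state_dict())
fired += [restored({"some_value": i}) is not None for i in range(7, 12)]

print([i for i, did_fire in enumerate(fired) if did_fire])
```

The hook fires on the first call and again on the eleventh, even though the checkpoint was taken in between.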
Checkpointing¶
The trainer class and hooks support checkpointing, which can be achieved either using the torchsnapshot backend or the regular torch backend. This can be controlled via the environment variable CKPT_BACKEND:
$ CKPT_BACKEND=torchsnapshot python script.py
CKPT_BACKEND defaults to torch. The advantage of torchsnapshot over the torch backend is that it offers a more flexible API, which supports distributed checkpointing and also allows users to load tensors from a file stored on disk to a tensor with a physical storage (which PyTorch currently does not support). This makes it possible, for instance, to load tensors from and to a replay buffer that would otherwise not fit in memory.
When building a trainer, one can provide a path where the checkpoints are to be written. With the torchsnapshot backend, a directory path is expected, whereas the torch backend expects a file path (typically a .pt file).
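Since the expected path shape depends on the backend, a script may want to derive it from CKPT_BACKEND. The following is a minimal sketch under that assumption; checkpoint_path and the trainer_ckpt name are hypothetical and not part of TorchRL.

```python
import os
from pathlib import Path


def checkpoint_path(root, backend=None):
    """Hypothetical helper: return a path whose shape matches the backend."""
    # Fall back to the environment variable, then to the documented default.
    backend = backend or os.environ.get("CKPT_BACKEND", "torch")
    root = Path(root)
    if backend == "torchsnapshot":
        return root / "trainer_ckpt"  # a directory is expected
    if backend == "torch":
        return root / "trainer_ckpt.pt"  # a single file is expected
    raise ValueError(f"unknown checkpoint backend: {backend!r}")


print(checkpoint_path("runs/exp0", "torch"))
print(checkpoint_path("runs/exp0", "torchsnapshot"))
```

The resulting path can then be passed as the checkpoint location when building the trainer.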
>>> from torchrl.trainers import SelectKeys, Trainer
>>>
>>> filepath = "path/to/dir/or/file"
>>> # collector, loss_module, optimizer and the integer settings are assumed
>>> # to have been built beforehand
>>> trainer = Trainer(
...     collector=collector,
...     total_frames=total_frames,
...     frame_skip=frame_skip,
...     optim_steps_per_batch=optim_steps_per_batch,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)
The Trainer.train() method can be used to execute the above loop with all of its hooks, although using the Trainer class only for its checkpointing capability is also a perfectly valid use.
Trainer and hooks¶
Data subsampler for online RL sota-implementations.
Clears cuda cache at a given interval.
A frame counter hook.
Generic scalar logger hook for any tensor values in the batch.
Add an optimizer for one or more loss components.
Recorder hook for
Replay buffer hook provider.
Reward normalizer hook.
Selects keys in a TensorDict batch.
A generic Trainer class.
An abstract hooking class for torchrl Trainer class.
A collector weights update hook class.
Algorithm-specific trainers (Experimental)¶
Warning
The following trainers are experimental/prototype features. The API may change in future versions. Please report any issues or feedback to help improve these implementations.
TorchRL provides high-level, algorithm-specific trainers that combine the modular components into complete training solutions with sensible defaults and comprehensive configuration options.
PPO (Proximal Policy Optimization) trainer implementation.
PPOTrainer¶
The PPOTrainer provides a complete PPO training solution with configurable defaults and a comprehensive configuration system built on Hydra.
Key Features:
Complete training pipeline with environment setup, data collection, and optimization
Extensive configuration system using dataclasses and Hydra
Built-in logging for rewards, actions, and training statistics
Modular design built on existing TorchRL components
Minimal code: Complete SOTA implementation in just ~20 lines!
Warning
This is an experimental feature. The API may change in future versions. We welcome feedback and contributions to help improve this implementation!
Quick Start - Command Line Interface:
# Basic usage - train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py
Custom Configuration:
# Override specific parameters via command line
python sota-implementations/ppo_trainer/train.py \
trainer.total_frames=2000000 \
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
networks.policy_network.num_cells=[256,256] \
optimizer.lr=0.0003
Environment Switching:
# Switch to a different environment and logger
python sota-implementations/ppo_trainer/train.py \
env=gym \
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
logger=tensorboard
See All Options:
# View all available configuration options
python sota-implementations/ppo_trainer/train.py --help
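The command-line overrides above address nested configuration entries with dotted keys. As a rough illustration of that mapping (not Hydra's actual parser, which supports a richer grammar), the sketch below applies the same override strings to a plain nested dictionary. The function apply_override and the starting values in cfg are hypothetical placeholders, not the real defaults.

```python
import ast


def apply_override(cfg, override):
    """Set a nested key from an 'a.b.c=value' string (simplified illustration)."""
    dotted_key, raw_value = override.split("=", 1)
    try:
        value = ast.literal_eval(raw_value)  # numbers, lists, ...
    except (ValueError, SyntaxError):
        value = raw_value  # keep plain strings such as HalfCheetah-v4
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})  # create intermediate levels on demand
    node[leaf] = value
    return cfg


# Placeholder starting configuration (values are not the real defaults).
cfg = {"trainer": {"total_frames": 1_000_000}, "optimizer": {"lr": 0.001}}
for item in [
    "trainer.total_frames=2000000",
    "training_env.create_env_fn.base_env.env_name=HalfCheetah-v4",
    "networks.policy_network.num_cells=[256,256]",
    "optimizer.lr=0.0003",
]:
    apply_override(cfg, item)

print(cfg["training_env"]["create_env_fn"]["base_env"]["env_name"])
```

Each dot in an override key descends one level in the configuration tree, which is why the environment name sits three levels below training_env.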
Configuration Groups:
The PPOTrainer configuration is organized into logical groups:
Environment: env_cfg__env_name, env_cfg__backend, env_cfg__device
Networks: actor_network__network__num_cells, critic_network__module__num_cells
Training: total_frames, clip_norm, num_epochs, optimizer_cfg__lr
Logging: log_rewards, log_actions, log_observations
Working Example:
The sota-implementations/ppo_trainer/ directory contains a complete, working PPO implementation that demonstrates the simplicity and power of the trainer system:
import hydra
from torchrl.trainers.algorithms.configs import *


@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
    trainer = hydra.utils.instantiate(cfg.trainer)
    trainer.train()


if __name__ == "__main__":
    main()
Complete PPO training with full configurability in ~20 lines!
Configuration Classes:
The PPOTrainer uses a hierarchical configuration system with these main config classes.
Note
The configuration system requires Python 3.10+ due to its use of modern type annotation syntax.
Trainer: PPOTrainerConfig
Environment: GymEnvConfig, BatchedEnvConfig
Networks: MLPConfig, TanhNormalModelConfig
Data: TensorDictReplayBufferConfig, MultiaSyncDataCollectorConfig
Objectives: PPOLossConfig
Optimizers: AdamConfig, AdamWConfig
Logging: WandbLoggerConfig, TensorboardLoggerConfig
Future Development:
This is the first of many planned algorithm-specific trainers. Future releases will include:
Additional algorithms: SAC, TD3, DQN, A2C, and more
Full integration of all TorchRL components within the configuration system
Enhanced configuration validation and error reporting
Distributed training support for high-level trainers
See the complete configuration system documentation for all available options.
Builders¶
Returns a data collector for off-policy sota-implementations.
Makes a collector in on-policy settings.
Builds the DQN loss module.
Builds a replay buffer using the config built from ReplayArgsConfig.
Builds a target network weight update object.
Creates a Trainer instance given its constituents.
Returns a parallel environment from an argparse.Namespace built with the appropriate parser constructor.
Runs asynchronous collectors, each running synchronous environments.
Runs synchronous collectors, each running synchronous environments.
Returns an environment creator from an argparse.Namespace built with the appropriate parser constructor.
Utils¶
Corrects the arguments for the input frame_skip by dividing all the arguments that reflect a count of frames by the frame_skip.
Gathers stats (loc and scale) from an environment using random rollouts.
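The frame_skip correction described above amounts to simple arithmetic. The sketch below shows the idea on a plain dictionary; correct_for_frame_skip is a hypothetical helper, the argument names are illustrative, and the use of integer division is an assumption rather than a statement about the library's behaviour.

```python
def correct_for_frame_skip(frame_counts, frame_skip):
    """Hypothetical helper: divide every frame-count argument by frame_skip."""
    if frame_skip < 1:
        raise ValueError("frame_skip must be a positive integer")
    # Integer division is assumed here so that the results stay frame counts.
    return {name: count // frame_skip for name, count in frame_counts.items()}


corrected = correct_for_frame_skip(
    {"total_frames": 1_000_000, "frames_per_batch": 1_000},
    frame_skip=4,
)
print(corrected)
```

With a frame_skip of 4, a budget of 1,000,000 frames corresponds to 250,000 environment steps.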
Loggers¶
A template for loggers.
A minimal-dependency CSV logger.
Wrapper for the mlflow logger.
Wrapper for the TensorBoard logger.
Wrapper for the wandb logger.
Get a logger instance of the provided logger_type.
Generates an ID (str) for the described experiment using UUID and current date.
Recording utils¶
Recording utils are detailed here.
Video Recorder transform.
TensorDict recorder.
A transform to call render on the parent environment and register the pixel observation in the tensordict.