.. currentmodule:: torchrl.trainers
torchrl.trainers package
========================
.. _ref_trainers:
The trainer package provides utilities to write re-usable training scripts. The core idea is to use a
trainer that implements a nested loop, where the outer loop runs the data collection steps and the inner
loop the optimization steps. We believe this fits multiple RL training schemes, such as
on-policy, off-policy, model-based and model-free solutions, offline RL and others.
More specialized cases, such as meta-RL algorithms, may have training schemes that differ substantially.
The ``trainer.train()`` method can be sketched as follows:
.. code-block::
:caption: Trainer loops
>>> for batch in collector:
... batch = self._process_batch_hook(batch) # "batch_process"
... self._pre_steps_log_hook(batch) # "pre_steps_log"
... self._pre_optim_hook() # "pre_optim_steps"
... for j in range(self.optim_steps_per_batch):
... sub_batch = self._process_optim_batch_hook(batch) # "process_optim_batch"
... losses = self.loss_module(sub_batch)
... self._post_loss_hook(sub_batch) # "post_loss"
... self.optimizer.step()
... self.optimizer.zero_grad()
... self._post_optim_hook() # "post_optim"
... self._post_optim_log(sub_batch) # "post_optim_log"
... self._post_steps_hook() # "post_steps"
... self._post_steps_log_hook(batch) # "post_steps_log"
There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``, ``"pre_steps_log"``,
``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are indicated in the comments where they are applied.
Hooks can be split into 3 categories: **data processing** (``"batch_process"`` and ``"process_optim_batch"``),
**logging** (``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operation** hooks
(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).
- **Data processing** hooks update a tensordict of data. The hook's ``__call__`` method should accept
  a ``TensorDict`` object as input and update it according to some strategy.
  Examples of such hooks include replay buffer extension (``ReplayBufferTrainer.extend``), data normalization
  (including normalization constants update) and data subsampling (:class:`~torchrl.trainers.BatchSubSampler`).
- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write to the logger
  some information retrieved from that data. Examples include the ``LogValidationReward`` hook, the reward
  logger (``LogScalar``) and the like. Hooks should return a dictionary (or a ``None`` value) containing the
  data to log. The key ``"log_pbar"`` is reserved for boolean values indicating whether the logged value
  should be displayed on the progress bar printed on the training log.
- **Operation** hooks execute specific operations over the models, data collectors, target networks
  and so on. For instance, syncing the weights of the collectors using ``UpdateWeights``
  or updating the priorities of the replay buffer using ``ReplayBufferTrainer.update_priority`` are examples
  of operation hooks. They are data-independent (they do not require a ``TensorDict``
  input); they are simply expected to be executed once at every iteration (or every N iterations).
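The sketch below shows how one hook from each category is typically created and attached to a
trainer. It assumes that ``trainer``, ``replay_buffer`` and ``collector`` objects have already been
built; the constructor arguments are illustrative and the exact signatures should be checked in each
class's documentation.

.. code-block::

    >>> from torchrl.trainers import LogScalar, ReplayBufferTrainer, UpdateWeights
    >>> # data processing: extend the buffer with collected batches and sample sub-batches from it
    >>> rb_hook = ReplayBufferTrainer(replay_buffer)
    >>> rb_hook.register(trainer)
    >>> # logging: record the training reward at each collection step
    >>> reward_hook = LogScalar()
    >>> reward_hook.register(trainer)
    >>> # operations: sync the collector's policy weights at every iteration
    >>> weight_hook = UpdateWeights(collector, 1)
    >>> weight_hook.register(trainer)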
The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
and all implement three base methods: a ``state_dict`` and a ``load_state_dict`` method for
checkpointing, and a ``register`` method that registers the hook at its default place in the
trainer. This method takes a trainer and a module name as input. For instance, the following logging
hook is executed every 10 calls to ``"post_optim_log"``:
.. code-block::
>>> class LoggingHook(TrainerHookBase):
... def __init__(self):
... self.counter = 0
...
...     def register(self, trainer, name="logging_hook"):
...         trainer.register_module(name, self)
...         trainer.register_op("post_optim_log", self)
...
...     def state_dict(self):
... return {"counter": self.counter}
...
... def load_state_dict(self, state_dict):
... self.counter = state_dict["counter"]
...
... def __call__(self, batch):
... if self.counter % 10 == 0:
... out = {"some_value": batch["some_value"].item(), "log_pbar": False}
... else:
... out = None
... self.counter += 1
... return out
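Once defined, the hook can be attached to an existing trainer through its ``register`` method (a
short usage sketch, assuming ``trainer`` has already been instantiated):

.. code-block::

    >>> logging_hook = LoggingHook()
    >>> logging_hook.register(trainer, name="logging_hook")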
Checkpointing
-------------
The trainer class and hooks support checkpointing, which can be achieved either
using the `torchsnapshot <https://github.com/pytorch/torchsnapshot>`_ backend or
the regular torch backend. This can be controlled via the global variable ``CKPT_BACKEND``:
.. code-block::
$ CKPT_BACKEND=torchsnapshot python script.py
``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch
is its more flexible API, which supports distributed checkpointing and
also allows users to load tensors from a file stored on disk into a tensor with
physical storage (which pytorch currently does not support). This makes it possible, for instance,
to load tensors to and from a replay buffer that would otherwise not fit in memory.
When building a trainer, one can provide a path where the checkpoints are to
be written. With the ``torchsnapshot`` backend, a directory path is expected,
whereas the ``torch`` backend expects a file path (typically a ``.pt`` file).
.. code-block::
>>> filepath = "path/to/dir/or/file"
>>> trainer = Trainer(
... collector=collector,
... total_frames=total_frames,
... frame_skip=frame_skip,
... loss_module=loss_module,
... optimizer=optimizer,
... save_trainer_file=filepath,
... )
>>> select_keys = SelectKeys(["action", "observation"])
>>> select_keys.register(trainer)
>>> # to save to a path
>>> trainer.save_trainer(True)
>>> # to load from a path
>>> trainer.load_from_file(filepath)
The ``Trainer.train()`` method can be used to execute the above loop with all of
its hooks, although using the :obj:`Trainer` class for its checkpointing capability
only is also a perfectly valid use.
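For instance, the sketch below uses the trainer solely as a checkpointing utility around a
hand-written loop (``collector`` and ``trainer`` are assumed to be built as above; the checkpoint
frequency is arbitrary):

.. code-block::

    >>> for i, batch in enumerate(collector):
    ...     ...  # custom data processing and optimization logic
    ...     if i % 100 == 0:
    ...         # force-write a checkpoint to the save_trainer_file location
    ...         trainer.save_trainer(True)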
Trainer and hooks
-----------------
.. autosummary::
:toctree: generated/
:template: rl_template.rst
BatchSubSampler
ClearCudaCache
CountFramesLog
LogScalar
OptimizerHook
LogValidationReward
ReplayBufferTrainer
RewardNormalizer
SelectKeys
Trainer
TrainerHookBase
UpdateWeights
TargetNetUpdaterHook
UTDRHook
Algorithm-specific trainers (Experimental)
------------------------------------------
.. warning::
The following trainers are experimental/prototype features. The API may change in future versions.
Please report any issues or feedback to help improve these implementations.
TorchRL provides high-level, algorithm-specific trainers that combine the modular components
into complete training solutions with sensible defaults and comprehensive configuration options.
.. currentmodule:: torchrl.trainers.algorithms
.. autosummary::
:toctree: generated/
:template: rl_template.rst
PPOTrainer
SACTrainer
Algorithm Trainers
~~~~~~~~~~~~~~~~~~
TorchRL provides high-level algorithm trainers that offer complete training solutions with minimal code.
These trainers feature comprehensive configuration systems built on Hydra, enabling both simple usage
and sophisticated customization.
**Currently Available:**
- :class:`~torchrl.trainers.algorithms.PPOTrainer` - Proximal Policy Optimization
- :class:`~torchrl.trainers.algorithms.SACTrainer` - Soft Actor-Critic
**Key Features:**
- **Complete pipeline**: Environment setup, data collection, and optimization
- **Hydra configuration**: Extensive dataclass-based configuration system
- **Built-in logging**: Rewards, actions, and algorithm-specific metrics
- **Modular design**: Built on existing TorchRL components
- **Minimal code**: Complete SOTA implementations in ~20 lines!
.. warning::
Algorithm trainers are experimental features. The API may change in future versions.
We welcome feedback and contributions to help improve these implementations!
Quick Start Examples
^^^^^^^^^^^^^^^^^^^^
**PPO Training:**
.. code-block:: bash
# Train PPO on Pendulum-v1 with default settings
python sota-implementations/ppo_trainer/train.py
**SAC Training:**
.. code-block:: bash
# Train SAC on a continuous control task
python sota-implementations/sac_trainer/train.py
**Custom Configuration:**
.. code-block:: bash
# Override parameters for any algorithm
python sota-implementations/ppo_trainer/train.py \
trainer.total_frames=2000000 \
training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
networks.policy_network.num_cells=[256,256] \
optimizer.lr=0.0003
**Environment Switching:**
.. code-block:: bash
# Switch environment and logger for any trainer
python sota-implementations/sac_trainer/train.py \
training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
logger=tensorboard \
logger.exp_name=sac_walker2d
**View Configuration Options:**
.. code-block:: bash
# See all available options for any trainer
python sota-implementations/ppo_trainer/train.py --help
python sota-implementations/sac_trainer/train.py --help
Universal Configuration System
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
All algorithm trainers share a unified configuration architecture organized into logical groups:
- **Environment**: ``training_env.create_env_fn.base_env.env_name``, ``training_env.num_workers``
- **Networks**: ``networks.policy_network.num_cells``, ``networks.value_network.num_cells``
- **Training**: ``trainer.total_frames``, ``trainer.clip_norm``, ``optimizer.lr``
- **Data**: ``collector.frames_per_batch``, ``replay_buffer.batch_size``, ``replay_buffer.storage.max_size``
- **Logging**: ``logger.exp_name``, ``logger.project``, ``trainer.log_interval``
**Working Example:**
All trainer implementations follow the same simple pattern:
.. code-block:: python
import hydra
from torchrl.trainers.algorithms.configs import *
@hydra.main(config_path="config", config_name="config", version_base="1.1")
def main(cfg):
trainer = hydra.utils.instantiate(cfg.trainer)
trainer.train()
if __name__ == "__main__":
main()
*Complete algorithm training with full configurability in ~20 lines!*
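The same configuration can also be composed programmatically (for example, from a notebook) with
Hydra's compose API. The sketch below mirrors the script above; it assumes it is executed from a
directory that contains the same ``config`` folder as the corresponding ``train.py`` script, and the
override shown is purely illustrative.

.. code-block:: python

    from hydra import compose, initialize
    from hydra.utils import instantiate

    from torchrl.trainers.algorithms.configs import *  # make the config classes available, as in train.py

    # Compose the config as @hydra.main would, apply a command-line-style override,
    # then build and run the trainer.
    with initialize(config_path="config", version_base="1.1"):
        cfg = compose(config_name="config", overrides=["optimizer.lr=0.0003"])
        trainer = instantiate(cfg.trainer)
        trainer.train()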
Configuration Classes
^^^^^^^^^^^^^^^^^^^^^
The trainer system uses a hierarchical configuration system with shared components.
.. note::
The configuration system requires Python 3.10+ due to its use of modern type annotation syntax.
**Algorithm-Specific Trainers:**
- **PPO**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
- **SAC**: :class:`~torchrl.trainers.algorithms.configs.trainers.SACTrainerConfig`
**Shared Configuration Components:**
- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`, :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig`
- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`, :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig`
- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`, :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig`
- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`, :class:`~torchrl.trainers.algorithms.configs.objectives.SACLossConfig`
- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`, :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig`
- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`, :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig`
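As a quick way to explore any of these configuration classes, a structured config can be built from
the dataclass and dumped to YAML. This is a minimal sketch that assumes the config dataclasses are
compatible with ``OmegaConf.structured``; the exact fields and defaults depend on the installed
version.

.. code-block:: python

    from omegaconf import OmegaConf

    from torchrl.trainers.algorithms.configs.trainers import PPOTrainerConfig

    # Build a typed OmegaConf config from the dataclass and print every field
    # together with its default value (or a MISSING marker).
    cfg = OmegaConf.structured(PPOTrainerConfig)
    print(OmegaConf.to_yaml(cfg))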
Algorithm-Specific Features
^^^^^^^^^^^^^^^^^^^^^^^^^^^
**PPOTrainer:**
- On-policy learning with advantage estimation
- Policy clipping and value function optimization
- Configurable number of epochs per batch
- Built-in GAE (Generalized Advantage Estimation)
**SACTrainer:**
- Off-policy learning with replay buffer
- Entropy-regularized policy optimization
- Target network soft updates
- Continuous action space optimization
**Future Development:**
The trainer system is actively expanding. Upcoming features include:
- Additional algorithms: TD3, DQN, A2C, DDPG, and more
- Enhanced distributed training support
- Advanced configuration validation and error reporting
- Integration with more TorchRL ecosystem components
See the complete configuration system documentation for all available options and examples.
Builders
--------
.. currentmodule:: torchrl.trainers.helpers
.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst
make_collector_offpolicy
make_collector_onpolicy
make_dqn_loss
make_replay_buffer
make_target_updater
make_trainer
parallel_env_constructor
sync_async_collector
sync_sync_collector
transformed_env_constructor
Utils
-----
.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst
correct_for_frame_skip
get_stats_random_rollout
Loggers
-------
.. _ref_loggers:
.. currentmodule:: torchrl.record.loggers
.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst
Logger
csv.CSVLogger
mlflow.MLFlowLogger
tensorboard.TensorboardLogger
wandb.WandbLogger
get_logger
generate_exp_name
Recording utils
---------------
Recording utils are detailed in the recording section of the documentation.
.. currentmodule:: torchrl.record
.. autosummary::
:toctree: generated/
:template: rl_template_fun.rst
VideoRecorder
TensorDictRecorder
PixelRenderTransform