.. currentmodule:: torchrl.trainers

torchrl.trainers package
========================

.. _ref_trainers:

The trainer package provides utilities to write re-usable training scripts. The core idea is to
use a trainer that implements a nested loop, where the outer loop runs the data collection steps
and the inner loop the optimization steps. We believe this fits multiple RL training schemes, such
as on-policy, off-policy, model-based and model-free solutions, offline RL and others. More
particular cases, such as meta-RL algorithms, may have training schemes that differ substantially.

The ``trainer.train()`` method can be sketched as follows:

.. code-block::
    :caption: Trainer loops

    >>> for batch in collector:
    ...     batch = self._process_batch_hook(batch)  # "batch_process"
    ...     self._pre_steps_log_hook(batch)  # "pre_steps_log"
    ...     self._pre_optim_hook()  # "pre_optim_steps"
    ...     for j in range(self.optim_steps_per_batch):
    ...         sub_batch = self._process_optim_batch_hook(batch)  # "process_optim_batch"
    ...         losses = self.loss_module(sub_batch)
    ...         self._post_loss_hook(sub_batch)  # "post_loss"
    ...         self.optimizer.step()
    ...         self.optimizer.zero_grad()
    ...         self._post_optim_hook()  # "post_optim"
    ...         self._post_optim_log(sub_batch)  # "post_optim_log"
    ...     self._post_steps_hook()  # "post_steps"
    ...     self._post_steps_log_hook(batch)  # "post_steps_log"

There are 10 hooks that can be used in a trainer loop: ``"batch_process"``, ``"pre_optim_steps"``,
``"process_optim_batch"``, ``"post_loss"``, ``"post_steps"``, ``"post_optim"``,
``"pre_steps_log"``, ``"post_steps_log"``, ``"post_optim_log"`` and ``"optimizer"``. They are
indicated in the comments where they are applied. Hooks can be split into 3 categories:
**data processing** (``"batch_process"`` and ``"process_optim_batch"``), **logging**
(``"pre_steps_log"``, ``"post_optim_log"`` and ``"post_steps_log"``) and **operations** hooks
(``"pre_optim_steps"``, ``"post_loss"``, ``"post_optim"`` and ``"post_steps"``).

- **Data processing** hooks update a tensordict of data. A hook's ``__call__`` method should
  accept a ``TensorDict`` object as input and update it given some strategy. Examples of such
  hooks include replay buffer extension (``ReplayBufferTrainer.extend``), data normalization
  (including normalization constant updates) and data subsampling
  (:class:`~torchrl.trainers.BatchSubSampler`).

- **Logging** hooks take a batch of data presented as a ``TensorDict`` and write some information
  retrieved from that data to the logger. Examples include the ``LogValidationReward`` hook and
  the reward logger (``LogScalar``). Hooks should return a dictionary (or a ``None`` value)
  containing the data to log. The key ``"log_pbar"`` is reserved for boolean values indicating
  whether the logged value should be displayed on the progress bar printed during training.

- **Operation** hooks execute specific operations over the models, data collectors, target
  network updates and such. For instance, syncing the weights of the collectors using
  ``UpdateWeights`` or updating the priority of the replay buffer using
  ``ReplayBufferTrainer.update_priority`` are examples of operation hooks. They are
  data-independent (they do not require a ``TensorDict`` input); they are just supposed to be
  executed once at every iteration (or every N iterations).

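
Any callable with the matching signature can be attached to one of these hook points through
``trainer.register_op``. The snippet below is a minimal sketch, assuming a ``trainer`` instance
has already been built; the tensordict keys and the scaling factor are purely illustrative:

.. code-block::

    >>> # Sketch: register one data processing hook and one logging hook.
    >>> def scale_reward(batch):
    ...     # data processing ("batch_process"): modify and return the batch
    ...     batch["next", "reward"] = batch["next", "reward"] / 10.0
    ...     return batch
    >>> def log_batch_size(batch):
    ...     # logging ("pre_steps_log"): return a dict of values to log
    ...     return {"batch_size": batch.numel(), "log_pbar": True}
    >>> trainer.register_op("batch_process", scale_reward)
    >>> trainer.register_op("pre_steps_log", log_batch_size)
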

The hooks provided by TorchRL usually inherit from a common abstract class ``TrainerHookBase``,
and all implement three base methods: a ``state_dict`` and a ``load_state_dict`` method for
checkpointing, and a ``register`` method that registers the hook at its default location in the
trainer. This method takes a trainer and a module name as input.

For instance, the following logging hook only writes a value every 10 calls to
``"post_optim_log"``:

.. code-block::

    >>> class LoggingHook(TrainerHookBase):
    ...     def __init__(self):
    ...         self.counter = 0
    ...
    ...     def register(self, trainer, name):
    ...         trainer.register_module(self, "logging_hook")
    ...         trainer.register_op("post_optim_log", self)
    ...
    ...     def state_dict(self):
    ...         return {"counter": self.counter}
    ...
    ...     def load_state_dict(self, state_dict):
    ...         self.counter = state_dict["counter"]
    ...
    ...     def __call__(self, batch):
    ...         if self.counter % 10 == 0:
    ...             out = {"some_value": batch["some_value"].item(), "log_pbar": False}
    ...         else:
    ...             out = None
    ...         self.counter += 1
    ...         return out

Checkpointing
-------------

The trainer class and hooks support checkpointing, which can be achieved either with the
``torchsnapshot`` backend or with the regular torch backend. This can be controlled via the
global variable ``CKPT_BACKEND``:

.. code-block::

    $ CKPT_BACKEND=torchsnapshot python script.py

``CKPT_BACKEND`` defaults to ``torch``. The advantage of torchsnapshot over pytorch is that it
offers a more flexible API, which supports distributed checkpointing and also allows users to load
tensors from a file stored on disk to a tensor with a physical storage (which pytorch currently
does not support). This makes it possible, for instance, to load tensors from and to a replay
buffer that would otherwise not fit in memory.

When building a trainer, one can provide a path where the checkpoints are to be written. With the
``torchsnapshot`` backend, a directory path is expected, whereas the ``torch`` backend expects a
file path (typically a ``.pt`` file).

.. code-block::

    >>> filepath = "path/to/dir/or/file"
    >>> trainer = Trainer(
    ...     collector=collector,
    ...     total_frames=total_frames,
    ...     frame_skip=frame_skip,
    ...     loss_module=loss_module,
    ...     optimizer=optimizer,
    ...     save_trainer_file=filepath,
    ... )
    >>> select_keys = SelectKeys(["action", "observation"])
    >>> select_keys.register(trainer)
    >>> # to save to a path
    >>> trainer.save_trainer(True)
    >>> # to load from a path
    >>> trainer.load_from_file(filepath)

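
Because custom hooks implement ``state_dict`` and ``load_state_dict`` and register themselves on
the trainer, their state can be checkpointed through the same mechanism. The following sketch
reuses the ``LoggingHook`` defined above together with the ``trainer`` and ``filepath`` from the
previous example; it is illustrative rather than a complete recipe:

.. code-block::

    >>> logging_hook = LoggingHook()
    >>> logging_hook.register(trainer, name="logging_hook")
    >>> # the hook now runs at every "post_optim_log" call; once registered, its
    >>> # ``counter`` should be saved and restored along with the trainer state
    >>> trainer.save_trainer(True)
    >>> trainer.load_from_file(filepath)
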

The ``Trainer.train()`` method can be used to execute the above loop with all of its hooks,
although using the :obj:`Trainer` class for its checkpointing capability only is also a perfectly
valid use.

Trainer and hooks
-----------------

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    BatchSubSampler
    ClearCudaCache
    CountFramesLog
    LogScalar
    OptimizerHook
    LogValidationReward
    ReplayBufferTrainer
    RewardNormalizer
    SelectKeys
    Trainer
    TrainerHookBase
    UpdateWeights

Algorithm-specific trainers (Experimental)
------------------------------------------

.. warning::
    The following trainers are experimental/prototype features. The API may change in future
    versions. Please report any issues or feedback to help improve these implementations.

TorchRL provides high-level, algorithm-specific trainers that combine the modular components into
complete training solutions with sensible defaults and comprehensive configuration options.

.. currentmodule:: torchrl.trainers.algorithms

.. autosummary::
    :toctree: generated/
    :template: rl_template.rst

    PPOTrainer

PPOTrainer
~~~~~~~~~~

The :class:`~torchrl.trainers.algorithms.PPOTrainer` provides a complete PPO training solution
with configurable defaults and a comprehensive configuration system built on Hydra.

**Key Features:**

- Complete training pipeline with environment setup, data collection, and optimization
- Extensive configuration system using dataclasses and Hydra
- Built-in logging for rewards, actions, and training statistics
- Modular design built on existing TorchRL components
- **Minimal code**: Complete SOTA implementation in just ~20 lines!

.. warning::
    This is an experimental feature. The API may change in future versions. We welcome feedback
    and contributions to help improve this implementation!

**Quick Start - Command Line Interface:**

.. code-block:: bash

    # Basic usage - train PPO on Pendulum-v1 with default settings
    python sota-implementations/ppo_trainer/train.py

**Custom Configuration:**

.. code-block:: bash

    # Override specific parameters via command line
    python sota-implementations/ppo_trainer/train.py \
        trainer.total_frames=2000000 \
        training_env.create_env_fn.base_env.env_name=HalfCheetah-v4 \
        networks.policy_network.num_cells=[256,256] \
        optimizer.lr=0.0003

**Environment Switching:**

.. code-block:: bash

    # Switch to a different environment and logger
    python sota-implementations/ppo_trainer/train.py \
        env=gym \
        training_env.create_env_fn.base_env.env_name=Walker2d-v4 \
        logger=tensorboard

**See All Options:**

.. code-block:: bash

    # View all available configuration options
    python sota-implementations/ppo_trainer/train.py --help

**Configuration Groups:**

The PPOTrainer configuration is organized into logical groups:

- **Environment**: ``env_cfg__env_name``, ``env_cfg__backend``, ``env_cfg__device``
- **Networks**: ``actor_network__network__num_cells``, ``critic_network__module__num_cells``
- **Training**: ``total_frames``, ``clip_norm``, ``num_epochs``, ``optimizer_cfg__lr``
- **Logging**: ``log_rewards``, ``log_actions``, ``log_observations``

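
Beyond the command line, the same configuration system can be driven programmatically with
Hydra's compose API and instantiated exactly as in the working example below. This is only a
sketch: the config path is a placeholder, and the override names follow the CLI examples above.

.. code-block:: python

    # Sketch: build and run a PPOTrainer programmatically with Hydra's compose API.
    from hydra import compose, initialize
    from hydra.utils import instantiate

    from torchrl.trainers.algorithms.configs import *  # make the config classes available

    # "config" below is a placeholder for the directory holding the trainer configs
    with initialize(config_path="config", version_base="1.1"):
        cfg = compose(
            config_name="config",
            overrides=["trainer.total_frames=100000", "optimizer.lr=3e-4"],
        )

    trainer = instantiate(cfg.trainer)
    trainer.train()
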

**Working Example:**

The ``sota-implementations/ppo_trainer/`` directory contains a complete, working PPO
implementation that demonstrates the simplicity and power of the trainer system:

.. code-block:: python

    import hydra
    from torchrl.trainers.algorithms.configs import *


    @hydra.main(config_path="config", config_name="config", version_base="1.1")
    def main(cfg):
        trainer = hydra.utils.instantiate(cfg.trainer)
        trainer.train()


    if __name__ == "__main__":
        main()

*Complete PPO training with full configurability in ~20 lines!*

**Configuration Classes:**

The PPOTrainer uses a hierarchical configuration system with these main config classes.

.. note::
    The configuration system requires Python 3.10+ due to its use of modern type annotation
    syntax.

- **Trainer**: :class:`~torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig`
- **Environment**: :class:`~torchrl.trainers.algorithms.configs.envs_libs.GymEnvConfig`,
  :class:`~torchrl.trainers.algorithms.configs.envs.BatchedEnvConfig`
- **Networks**: :class:`~torchrl.trainers.algorithms.configs.modules.MLPConfig`,
  :class:`~torchrl.trainers.algorithms.configs.modules.TanhNormalModelConfig`
- **Data**: :class:`~torchrl.trainers.algorithms.configs.data.TensorDictReplayBufferConfig`,
  :class:`~torchrl.trainers.algorithms.configs.collectors.MultiaSyncDataCollectorConfig`
- **Objectives**: :class:`~torchrl.trainers.algorithms.configs.objectives.PPOLossConfig`
- **Optimizers**: :class:`~torchrl.trainers.algorithms.configs.utils.AdamConfig`,
  :class:`~torchrl.trainers.algorithms.configs.utils.AdamWConfig`
- **Logging**: :class:`~torchrl.trainers.algorithms.configs.logging.WandbLoggerConfig`,
  :class:`~torchrl.trainers.algorithms.configs.logging.TensorboardLoggerConfig`

**Future Development:**

This is the first of many planned algorithm-specific trainers. Future releases will include:

- Additional algorithms: SAC, TD3, DQN, A2C, and more
- Full integration of all TorchRL components within the configuration system
- Enhanced configuration validation and error reporting
- Distributed training support for high-level trainers

See the complete configuration system documentation for all available options.

Builders
--------

.. currentmodule:: torchrl.trainers.helpers

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    make_collector_offpolicy
    make_collector_onpolicy
    make_dqn_loss
    make_replay_buffer
    make_target_updater
    make_trainer
    parallel_env_constructor
    sync_async_collector
    sync_sync_collector
    transformed_env_constructor

Utils
-----

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    correct_for_frame_skip
    get_stats_random_rollout

Loggers
-------

.. _ref_loggers:

.. currentmodule:: torchrl.record.loggers

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    Logger
    csv.CSVLogger
    mlflow.MLFlowLogger
    tensorboard.TensorboardLogger
    wandb.WandbLogger
    get_logger
    generate_exp_name

Recording utils
---------------

Recording utils are detailed in the recording section of the documentation.

.. currentmodule:: torchrl.record

.. autosummary::
    :toctree: generated/
    :template: rl_template_fun.rst

    VideoRecorder
    TensorDictRecorder
    PixelRenderTransform