Source code for torchrl.trainers.algorithms.configs.trainers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch
from tensordict.nn import TensorDictModuleBase

from torchrl.collectors import DataCollectorBase
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.objectives.value.advantages import GAE
from torchrl.trainers.algorithms.configs.common import ConfigBase
from torchrl.trainers.algorithms.ppo import PPOTrainer
from torchrl.trainers.algorithms.sac import SACTrainer


@dataclass
class TrainerConfig(ConfigBase):
    """Base configuration class for trainers."""

    def __post_init__(self) -> None:
        """Post-initialization hook for trainer configurations."""
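
Concrete trainer configurations subclass TrainerConfig, declare their fields as a dataclass, and point _target_ at a factory function so that Hydra's instantiate can build the trainer from the config. A minimal sketch of that pattern, assuming Hydra is available; MyTrainerConfig and _make_my_trainer are illustrative names only, not part of torchrl:

# Hedged sketch of the config-to-factory pattern used by the configs below.
# MyTrainerConfig and _make_my_trainer are illustrative names, not torchrl APIs.
from dataclasses import dataclass
from typing import Any


@dataclass
class MyTrainerConfig(TrainerConfig):
    collector: Any = None          # already-built object or a partial (callable)
    total_frames: int = 1_000_000
    # Dotted path resolved by hydra.utils.instantiate and called with the fields.
    _target_: str = "my_package.trainers._make_my_trainer"


def _make_my_trainer(*args, **kwargs):
    # Receives the dataclass fields as keyword arguments and returns a trainer.
    collector = kwargs.pop("collector")
    ...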

@dataclass
class SACTrainerConfig(TrainerConfig):
    """Configuration class for SAC (Soft Actor Critic) trainer.

    This class defines the configuration parameters for creating a SAC trainer,
    including both required and optional fields with sensible defaults.
    """

    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000
    create_env_fn: Any = None
    actor_network: Any = None
    critic_network: Any = None
    target_net_updater: Any = None
    async_collection: bool = False
    log_timings: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"

    def __post_init__(self) -> None:
        """Post-initialization hook for SAC trainer configuration."""
        super().__post_init__()


def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
    from torchrl.trainers.trainers import Logger

    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    if total_frames is None:
        total_frames = collector.total_frames
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm")
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed")
    actor_network = kwargs.pop("actor_network")
    critic_network = kwargs.pop("critic_network")
    kwargs.pop("create_env_fn")
    target_net_updater = kwargs.pop("target_net_updater")
    async_collection = kwargs.pop("async_collection", False)
    log_timings = kwargs.pop("log_timings", False)

    # Instantiate networks first
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()

    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if not async_collection:
            collector = collector()
        elif replay_buffer is not None:
            collector = collector(replay_buffer=replay_buffer)
    elif getattr(collector, "replay_buffer", None) is None:
        if async_collection and (
            collector.replay_buffer is None or replay_buffer is None
        ):
            raise ValueError(
                "replay_buffer must be provided when async_collection is True"
            )
    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(target_net_updater, TargetNetUpdater):
        # target_net_updater must be a partial taking the loss as input
        target_net_updater = target_net_updater(loss_module)
    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())

    # Quick instance checks
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )
    if not isinstance(logger, Logger) and logger is not None:
        raise ValueError(f"logger must be a Logger, got {type(logger)}")
    return SACTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        target_net_updater=target_net_updater,
        async_collection=async_collection,
        log_timings=log_timings,
    )
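
A hedged sketch of how _make_sac_trainer is typically reached: each field of SACTrainerConfig holds either an already-built object or a partial (a callable returning the object), and hydra.utils.instantiate dispatches to the factory through _target_. The make_* helpers below are placeholders for whatever builds those objects in a real config; they are not torchrl APIs.

# Hedged usage sketch, not part of this module. The make_* callables are
# placeholders standing in for real builders (networks, collector, loss, ...).
from hydra.utils import instantiate

cfg = SACTrainerConfig(
    collector=make_collector,         # partial: called as make_collector()
    total_frames=1_000_000,
    optim_steps_per_batch=1,
    loss_module=make_loss,            # called with actor_network/critic_network
    optimizer=make_optimizer,         # called with params=loss_module.parameters()
    logger=None,
    save_trainer_file="sac_trainer.pt",
    replay_buffer=None,
    actor_network=make_actor,         # instantiated first via make_actor()
    critic_network=make_critic,
    target_net_updater=make_updater,  # called with the built loss module
)
trainer = instantiate(cfg)            # resolves _target_ -> _make_sac_trainer
trainer.train()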

@dataclass
class PPOTrainerConfig(TrainerConfig):
    """Configuration class for PPO (Proximal Policy Optimization) trainer.

    This class defines the configuration parameters for creating a PPO trainer,
    including both required and optional fields with sensible defaults.

    Args:
        collector: The data collector for gathering training data.
        total_frames: Total number of frames to train for.
        optim_steps_per_batch: Number of optimization steps per batch.
        loss_module: The loss module for computing policy and value losses.
        optimizer: The optimizer for training.
        logger: Logger for tracking training metrics.
        save_trainer_file: File path for saving trainer state.
        replay_buffer: Replay buffer for storing data.
        frame_skip: Frame skip value for the environment. Default: 1.
        clip_grad_norm: Whether to clip gradient norms. Default: True.
        clip_norm: Maximum gradient norm value.
        progress_bar: Whether to show a progress bar. Default: True.
        seed: Random seed for reproducibility.
        save_trainer_interval: Interval for saving trainer state. Default: 10000.
        log_interval: Interval for logging metrics. Default: 10000.
        create_env_fn: Environment creation function.
        actor_network: Actor network configuration.
        critic_network: Critic network configuration.
        num_epochs: Number of epochs per batch. Default: 4.
        async_collection: Whether to use async collection. Default: False.
        add_gae: Whether to add GAE computation. Default: True.
        gae: Custom GAE module configuration.
        weight_update_map: Mapping from collector destination paths to trainer
            source paths. Required if the collector has weight_sync_schemes
            configured. Example:
            {"policy": "loss_module.actor_network",
             "replay_buffer.transforms[0]": "loss_module.critic_network"}
        log_timings: Whether to automatically log timing information for all
            hooks. If True, timing metrics will be logged to the logger (e.g.,
            wandb, tensorboard) with prefix "time/" (e.g.,
            "time/hook/UpdateWeights"). Default: False.
    """

    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000
    create_env_fn: Any = None
    actor_network: Any = None
    critic_network: Any = None
    num_epochs: int = 4
    async_collection: bool = False
    add_gae: bool = True
    gae: Any = None
    weight_update_map: dict[str, str] | None = None
    log_timings: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"

    def __post_init__(self) -> None:
        """Post-initialization hook for PPO trainer configuration."""
        super().__post_init__()
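
The weight_update_map entry ties destinations on the collector side to sources on the trainer side, as in the docstring example above. A sketch of such a mapping; the exact dotted paths depend on how the loss module and replay-buffer transforms are laid out in a given setup:

# Illustrative mapping only; real paths depend on the actual module layout.
weight_update_map = {
    "policy": "loss_module.actor_network",
    "replay_buffer.transforms[0]": "loss_module.critic_network",
}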

def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
    from torchrl.trainers.trainers import Logger

    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    if total_frames is None:
        total_frames = collector.total_frames
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm")
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed")
    actor_network = kwargs.pop("actor_network")
    critic_network = kwargs.pop("critic_network")
    add_gae = kwargs.pop("add_gae", True)
    gae = kwargs.pop("gae")
    create_env_fn = kwargs.pop("create_env_fn")
    weight_update_map = kwargs.pop("weight_update_map", None)
    log_timings = kwargs.pop("log_timings", False)
    if create_env_fn is not None:
        # could be referenced somewhere else, no need to raise an error
        pass
    num_epochs = kwargs.pop("num_epochs", 4)
    async_collection = kwargs.pop("async_collection", False)

    # Instantiate networks first
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()
    else:
        critic_network = loss_module.critic_network

    # Ensure GAE in replay buffer uses the same value network instance as loss module
    # This fixes the issue where Hydra instantiates separate instances of value_model
    if (
        replay_buffer is not None
        and hasattr(replay_buffer, "_transform")
        and len(replay_buffer._transform) > 1
        and hasattr(replay_buffer._transform[1], "module")
        and hasattr(replay_buffer._transform[1].module, "value_network")
    ):
        replay_buffer._transform[1].module.value_network = critic_network

    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if not async_collection:
            collector = collector()
        else:
            collector = collector(replay_buffer=replay_buffer)
    elif async_collection and getattr(collector, "replay_buffer", None) is None:
        raise RuntimeError(
            "replay_buffer must be provided when async_collection is True"
        )
    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())

    # Quick instance checks
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )
    if not isinstance(logger, Logger) and logger is not None:
        raise ValueError(f"logger must be a Logger, got {type(logger)}")

    # instantiate gae if it is a partial config
    if not isinstance(gae, (GAE, TensorDictModuleBase)) and gae is not None:
        gae = gae()

    return PPOTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        num_epochs=num_epochs,
        async_collection=async_collection,
        add_gae=add_gae,
        gae=gae,
        weight_update_map=weight_update_map,
        log_timings=log_timings,
    )
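
As with SAC, the PPO factory accepts a mix of already-built objects and partials. A hedged end-to-end sketch, with make_* placeholders standing in for real builders (they are not torchrl APIs); note that when critic_network is left unset the factory falls back to loss_module.critic_network, and that a GAE transform sitting in the replay buffer is rewired to share that same critic instance.

# Hedged usage sketch, not part of this module. The make_* callables are
# placeholders; weight_update_map follows the docstring example above.
from hydra.utils import instantiate

cfg = PPOTrainerConfig(
    collector=make_collector,        # partial: called as make_collector()
    total_frames=1_000_000,
    optim_steps_per_batch=1,
    loss_module=make_loss,           # called with actor_network/critic_network
    optimizer=make_optimizer,        # called with params=loss_module.parameters()
    logger=None,
    save_trainer_file="ppo_trainer.pt",
    replay_buffer=make_buffer(),     # built buffer; its GAE transform, if any,
    actor_network=make_actor,        # is rewired to the loss critic instance
    critic_network=make_critic,
    gae=make_gae,                    # partial: instantiated as make_gae()
    num_epochs=4,
    weight_update_map={"policy": "loss_module.actor_network"},
)
trainer = instantiate(cfg)           # resolves _target_ -> _make_ppo_trainer
trainer.train()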
