
torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig

class torchrl.trainers.algorithms.configs.trainers.PPOTrainerConfig(collector: Any, total_frames: int, optim_steps_per_batch: int | None, loss_module: Any, optimizer: Any, logger: Any, save_trainer_file: Any, replay_buffer: Any, frame_skip: int = 1, clip_grad_norm: bool = True, clip_norm: float | None = None, progress_bar: bool = True, seed: int | None = None, save_trainer_interval: int = 10000, log_interval: int = 10000, create_env_fn: Any = None, actor_network: Any = None, critic_network: Any = None, num_epochs: int = 4, async_collection: bool = False, add_gae: bool = True, gae: Any = None, weight_update_map: dict[str, str] | None = None, log_timings: bool = False, _target_: str = 'torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer')[source]

Configuration class for the PPO (Proximal Policy Optimization) trainer.

This class defines the configuration parameters for creating a PPO trainer, covering both the required fields and the optional fields that ship with sensible defaults.
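
For illustration, here is a minimal sketch of building this config and resolving it into a trainer, assuming Hydra-style instantiation through the _target_ field (which points at the _make_ppo_trainer factory). The None placeholders and the literal values below are hypothetical stand-ins, not recommended settings:

    from hydra.utils import instantiate
    from torchrl.trainers.algorithms.configs.trainers import PPOTrainerConfig

    # Hypothetical sketch: each None stands in for the corresponding
    # TorchRL sub-config (collector, loss module, optimizer, ...).
    cfg = PPOTrainerConfig(
        collector=None,
        total_frames=1_000_000,
        optim_steps_per_batch=1,
        loss_module=None,
        optimizer=None,
        logger=None,
        save_trainer_file="ppo_trainer.pt",
        replay_buffer=None,
        num_epochs=4,
    )

    # Resolving the config invokes the _target_ factory:
    trainer = instantiate(cfg)
    trainer.train()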

Parameters:
  • collector – The data collector for gathering training data.

  • total_frames – Total number of frames to train for.

  • optim_steps_per_batch – Number of optimization steps per batch.

  • loss_module – The loss module for computing policy and value losses.

  • optimizer – The optimizer for training.

  • logger – Logger for tracking training metrics.

  • save_trainer_file – File path for saving trainer state.

  • replay_buffer – Replay buffer for storing data.

  • frame_skip – Frame skip value for the environment. Default: 1.

  • clip_grad_norm – Whether to clip gradient norms. Default: True.

  • clip_norm – Maximum gradient norm value.

  • progress_bar – Whether to show a progress bar. Default: True.

  • seed – Random seed for reproducibility.

  • save_trainer_interval – Interval for saving trainer state. Default: 10000.

  • log_interval – Interval for logging metrics. Default: 10000.

  • create_env_fn – Environment creation function.

  • actor_network – Actor network configuration.

  • critic_network – Critic network configuration.

  • num_epochs – Number of epochs per batch. Default: 4.

  • async_collection – Whether to use asynchronous data collection. Default: False.

  • add_gae – Whether to add GAE (Generalized Advantage Estimation) computation. Default: True.

  • gae – Custom GAE module configuration.

  • weight_update_map – Mapping from collector destination paths to trainer source paths. Required if the collector has weight_sync_schemes configured. Example: {"policy": "loss_module.actor_network", "replay_buffer.transforms[0]": "loss_module.critic_network"}. See the sketch after this list.

  • log_timings – Whether to automatically log timing information for all hooks. If True, timing metrics will be logged to the logger (e.g., wandb, tensorboard) with the prefix "time/" (e.g., "time/hook/UpdateWeights"). Default: False.
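
As referenced in the weight_update_map entry above, a hedged sketch of setting that mapping on an existing config object (cfg is assumed to be the PPOTrainerConfig from the earlier sketch); the mapping itself is the example given in the parameter description:

    # Maps collector destination paths (keys) to trainer source paths
    # (values); required when the collector configures weight_sync_schemes.
    cfg.weight_update_map = {
        "policy": "loss_module.actor_network",
        "replay_buffer.transforms[0]": "loss_module.critic_network",
    }

    # Optionally log per-hook timings under the "time/" prefix:
    cfg.log_timings = True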
