Shortcuts

Source code for torchrl.trainers.algorithms.configs.trainers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch

from torchrl.collectors import DataCollectorBase
from torchrl.objectives.common import LossModule
from torchrl.trainers.algorithms.configs.common import ConfigBase
from torchrl.trainers.algorithms.ppo import PPOTrainer


@dataclass
class TrainerConfig(ConfigBase):
    """Common parent for trainer configuration dataclasses.

    The class declares no fields of its own; algorithm-specific trainer
    configurations subclass it and add their own parameters.
    """

    def __post_init__(self) -> None:
        """Run after dataclass initialization; there is nothing to do here."""
        return None
@dataclass
class PPOTrainerConfig(TrainerConfig):
    """Settings used to build a PPO (Proximal Policy Optimization) trainer.

    The first group of fields has no default and must always be supplied;
    the remaining fields are optional and fall back to the defaults shown
    below. ``_target_`` holds the dotted path of the ``_make_ppo_trainer``
    factory of this module, which is what turns these settings into a
    trainer (presumably through Hydra-style instantiation -- confirm with
    the calling code).
    """

    # --- required fields (no default) -----------------------------------
    # ``collector``, ``loss_module`` and ``optimizer`` may each be either a
    # ready-made object or a callable that ``_make_ppo_trainer`` invokes to
    # build one.
    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any

    # --- optional fields, handed to the trainer by the factory ------------
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000

    # --- optional building blocks consumed by the factory itself ----------
    # ``create_env_fn`` is handed to the collector callable; the two network
    # entries are called with no arguments when they are not ``None``.
    create_env_fn: Any = None
    actor_network: Any = None
    critic_network: Any = None

    num_epochs: int = 4

    # Dotted import path of the factory that consumes this configuration.
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"

    def __post_init__(self) -> None:
        """Delegate to the parent class's post-initialization hook."""
        super().__post_init__()
def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
    """Build a :class:`PPOTrainer` from the fields of a ``PPOTrainerConfig``.

    ``collector``, ``loss_module`` and ``optimizer`` may be passed either as
    ready-made objects or as "partial" callables. Callables are invoked here,
    in dependency order: networks first, then collector and loss module, then
    the optimizer (which needs the loss module's parameters).

    Args:
        *args: Accepted for call-signature compatibility; ignored.
        **kwargs: Configuration entries. ``collector``, ``total_frames``,
            ``loss_module``, ``optimizer``, ``logger``, ``replay_buffer`` and
            ``save_trainer_file`` are required. Every other recognised key
            falls back to the default declared on ``PPOTrainerConfig``.
            Unrecognised keys are left in ``kwargs`` and ignored.

    Returns:
        The configured ``PPOTrainer``.

    Raises:
        KeyError: If one of the required keys is missing.
        ValueError: If, after instantiation, the collector, loss module,
            optimizer or logger does not have the expected type.
    """
    # Kept as a function-level import as in the original code (presumably to
    # avoid a circular import between the trainer modules -- TODO confirm).
    from torchrl.trainers.trainers import Logger

    # Required entries: a missing key raises KeyError.
    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_file = kwargs.pop("save_trainer_file")

    # Optional entries: defaults mirror the ones declared on PPOTrainerConfig,
    # so calling the factory directly behaves like going through the config.
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm", None)
    progress_bar = kwargs.pop("progress_bar", True)
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    seed = kwargs.pop("seed", None)
    actor_network = kwargs.pop("actor_network", None)
    critic_network = kwargs.pop("critic_network", None)
    create_env_fn = kwargs.pop("create_env_fn", None)
    num_epochs = kwargs.pop("num_epochs", 4)

    # Instantiate networks first: the collector and the loss module need them.
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()

    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        collector = collector(create_env_fn=create_env_fn, policy=actor_network)
    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())

    # Quick instance checks on whatever the callables above returned.
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )
    if logger is not None and not isinstance(logger, Logger):
        raise ValueError(f"logger must be a Logger, got {type(logger)}")

    # Resolve the frame budget only now that the collector is a real object.
    # Reading ``collector.total_frames`` before instantiation fails with
    # AttributeError when the collector is still a partial callable.
    if total_frames is None:
        total_frames = collector.total_frames

    return PPOTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        num_epochs=num_epochs,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources