Source code for torchrl.trainers.algorithms.configs.trainers
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch

from torchrl.collectors import DataCollectorBase
from torchrl.objectives.common import LossModule
from torchrl.trainers.algorithms.configs.common import ConfigBase
from torchrl.trainers.algorithms.ppo import PPOTrainer


@dataclass
class TrainerConfig(ConfigBase):
    """Base configuration class for trainers."""

    def __post_init__(self) -> None:
        """Post-initialization hook for trainer configurations."""


@dataclass
class PPOTrainerConfig(TrainerConfig):
    """Configuration class for the PPO (Proximal Policy Optimization) trainer.

    This class defines the configuration parameters for creating a PPO trainer,
    including both required and optional fields with sensible defaults.
    """

    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000
    create_env_fn: Any = None
    actor_network: Any = None
    critic_network: Any = None
    num_epochs: int = 4
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"

    def __post_init__(self) -> None:
        """Post-initialization hook for PPO trainer configuration."""
        super().__post_init__()
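

# Usage sketch (illustrative): the ``_target_`` field above lets a Hydra-style
# ``instantiate`` call resolve a populated ``PPOTrainerConfig`` by invoking
# ``_make_ppo_trainer`` below. The ``*_cfg`` names are hypothetical sub-configs,
# not defined in this module.
#
#     from hydra.utils import instantiate
#
#     cfg = PPOTrainerConfig(
#         collector=collector_cfg,        # partial collector config (hypothetical)
#         total_frames=1_000_000,
#         optim_steps_per_batch=1,
#         loss_module=loss_cfg,           # partial loss config (hypothetical)
#         optimizer=optimizer_cfg,        # partial optimizer config (hypothetical)
#         logger=None,
#         save_trainer_file="ppo_trainer.pt",
#         replay_buffer=buffer_cfg,
#         actor_network=actor_cfg,
#         critic_network=critic_cfg,
#         create_env_fn=env_cfg,
#     )
#     trainer = instantiate(cfg)          # dispatches to _make_ppo_trainer
#     trainer.train()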


def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
    # Local import of the Logger base class used for validation below.
    from torchrl.trainers.trainers import Logger

    # Pull every field of the resolved config out of ``kwargs``.
    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    if total_frames is None:
        # Fall back to the number of frames the collector is set to gather.
        total_frames = collector.total_frames
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm")
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed")
    actor_network = kwargs.pop("actor_network")
    critic_network = kwargs.pop("critic_network")
    create_env_fn = kwargs.pop("create_env_fn")
    num_epochs = kwargs.pop("num_epochs", 4)
    # Instantiate the networks first: they may arrive as partial configs
    # (callables) rather than ready-made modules.
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()
    if not isinstance(collector, DataCollectorBase):
        # Partial config: build the collector around the env factory and policy.
        collector = collector(create_env_fn=create_env_fn, policy=actor_network)
    if not isinstance(loss_module, LossModule):
        # Partial config: bind the freshly built networks to the loss.
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(optimizer, torch.optim.Optimizer):
        # Partial config: the optimizer still needs the loss parameters.
        optimizer = optimizer(params=loss_module.parameters())
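    # For illustration: a "partial config" here is any callable that finishes
    # construction once its dependencies are supplied. A hypothetical sketch
    # using ``functools.partial``:
    #
    #     from functools import partial
    #     from torchrl.collectors import SyncDataCollector
    #
    #     collector_cfg = partial(
    #         SyncDataCollector, frames_per_batch=1000, total_frames=100_000
    #     )
    #     # which this function then completes as:
    #     # collector_cfg(create_env_fn=create_env_fn, policy=actor_network)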
    # Quick instance checks
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )
    if logger is not None and not isinstance(logger, Logger):
        raise ValueError(f"logger must be a Logger or None, got {type(logger)}")
    return PPOTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        num_epochs=num_epochs,
    )
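

# End-to-end sketch (illustrative): calling ``_make_ppo_trainer`` directly with
# concrete objects. ``actor`` and ``critic`` stand in for pre-built TorchRL
# policy/value modules and are assumed, not defined here.
#
#     from functools import partial
#
#     from torch.optim import Adam
#     from torchrl.collectors import SyncDataCollector
#     from torchrl.data import LazyTensorStorage, ReplayBuffer
#     from torchrl.envs import GymEnv
#     from torchrl.objectives import ClipPPOLoss
#
#     env_maker = partial(GymEnv, "Pendulum-v1")
#     collector = SyncDataCollector(
#         env_maker, actor, frames_per_batch=1000, total_frames=100_000
#     )
#     loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
#     trainer = _make_ppo_trainer(
#         collector=collector,
#         total_frames=None,              # falls back to collector.total_frames
#         optim_steps_per_batch=1,
#         loss_module=loss,
#         optimizer=Adam(loss.parameters(), lr=3e-4),
#         logger=None,
#         clip_norm=None,
#         save_trainer_file=None,
#         seed=0,
#         replay_buffer=ReplayBuffer(storage=LazyTensorStorage(100_000)),
#         actor_network=None,
#         critic_network=None,
#         create_env_fn=env_maker,
#     )
#     trainer.train()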