Source code for torchrl.trainers.algorithms.configs.trainers
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch

from torchrl.collectors import DataCollectorBase
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.trainers.algorithms.configs.common import ConfigBase
from torchrl.trainers.algorithms.ppo import PPOTrainer
from torchrl.trainers.algorithms.sac import SACTrainer


@dataclass
class TrainerConfig(ConfigBase):
"""Base configuration class for trainers."""
def __post_init__(self) -> None:
"""Post-initialization hook for trainer configurations."""


@dataclass
class SACTrainerConfig(TrainerConfig):
"""Configuration class for SAC (Soft Actor Critic) trainer.
This class defines the configuration parameters for creating a SAC trainer,
including both required and optional fields with sensible defaults.
"""
collector: Any
    total_frames: int | None
optim_steps_per_batch: int | None
loss_module: Any
optimizer: Any
logger: Any
save_trainer_file: Any
replay_buffer: Any
frame_skip: int = 1
clip_grad_norm: bool = True
clip_norm: float | None = None
progress_bar: bool = True
seed: int | None = None
save_trainer_interval: int = 10000
log_interval: int = 10000
create_env_fn: Any = None
actor_network: Any = None
critic_network: Any = None
target_net_updater: Any = None
async_collection: bool = False
_target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"
def __post_init__(self) -> None:
"""Post-initialization hook for SAC trainer configuration."""
super().__post_init__()
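

# Usage sketch (illustrative, not part of the module): ``SACTrainerConfig`` is
# a Hydra-style structured config, so it is normally resolved with
# ``hydra.utils.instantiate``, which dispatches to ``_make_sac_trainer``
# through the ``_target_`` field. The sub-configs below (``collector_cfg`` and
# friends) are hypothetical placeholders.
#
#     from hydra.utils import instantiate
#
#     cfg = SACTrainerConfig(
#         collector=collector_cfg,
#         total_frames=1_000_000,
#         optim_steps_per_batch=1,
#         loss_module=loss_cfg,
#         optimizer=optimizer_cfg,
#         logger=None,
#         save_trainer_file="sac_trainer.pt",
#         replay_buffer=buffer_cfg,
#     )
#     trainer = instantiate(cfg)
#     trainer.train()
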
def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
from torchrl.trainers.trainers import Logger
collector = kwargs.pop("collector")
total_frames = kwargs.pop("total_frames")
if total_frames is None:
total_frames = collector.total_frames
frame_skip = kwargs.pop("frame_skip", 1)
optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
loss_module = kwargs.pop("loss_module")
optimizer = kwargs.pop("optimizer")
logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm", None)
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed", None)
    actor_network = kwargs.pop("actor_network", None)
    critic_network = kwargs.pop("critic_network", None)
    # create_env_fn may be referenced elsewhere in the config, so it is
    # discarded here rather than treated as an error
    kwargs.pop("create_env_fn", None)
    target_net_updater = kwargs.pop("target_net_updater", None)
    async_collection = kwargs.pop("async_collection", False)
    # Instantiate the networks first: when provided through the config they
    # are partials that build the actual modules once called
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()
    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if not async_collection:
            collector = collector()
        elif replay_buffer is not None:
            # async collectors write directly to the replay buffer
            collector = collector(replay_buffer=replay_buffer)
        else:
            raise ValueError(
                "replay_buffer must be provided when async_collection is True"
            )
    elif async_collection and getattr(collector, "replay_buffer", None) is None:
        raise ValueError(
            "replay_buffer must be provided when async_collection is True"
        )
if not isinstance(loss_module, LossModule):
# then it's a partial config
loss_module = loss_module(
actor_network=actor_network, critic_network=critic_network
)
    if target_net_updater is not None and not isinstance(
        target_net_updater, TargetNetUpdater
    ):
        # target_net_updater must be a partial taking the loss module as input
        target_net_updater = target_net_updater(loss_module)
if not isinstance(optimizer, torch.optim.Optimizer):
# then it's a partial config
optimizer = optimizer(params=loss_module.parameters())
# Quick instance checks
if not isinstance(collector, DataCollectorBase):
raise ValueError(
f"collector must be a DataCollectorBase, got {type(collector)}"
)
if not isinstance(loss_module, LossModule):
raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
if not isinstance(optimizer, torch.optim.Optimizer):
raise ValueError(
f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
)
    if logger is not None and not isinstance(logger, Logger):
raise ValueError(f"logger must be a Logger, got {type(logger)}")
return SACTrainer(
collector=collector,
total_frames=total_frames,
frame_skip=frame_skip,
optim_steps_per_batch=optim_steps_per_batch,
loss_module=loss_module,
optimizer=optimizer,
logger=logger,
clip_grad_norm=clip_grad_norm,
clip_norm=clip_norm,
progress_bar=progress_bar,
seed=seed,
save_trainer_interval=save_trainer_interval,
log_interval=log_interval,
save_trainer_file=save_trainer_file,
replay_buffer=replay_buffer,
target_net_updater=target_net_updater,
async_collection=async_collection,
)
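

# The "partial config" convention used above: any field that is not already an
# instance of the expected type is assumed to be a callable that builds one.
# A minimal illustration with ``functools.partial`` (the learning rate is an
# arbitrary example value):
#
#     from functools import partial
#
#     import torch
#
#     optimizer_partial = partial(torch.optim.Adam, lr=3e-4)
#     # resolved inside _make_sac_trainer as:
#     # optimizer = optimizer_partial(params=loss_module.parameters())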


@dataclass
class PPOTrainerConfig(TrainerConfig):
"""Configuration class for PPO (Proximal Policy Optimization) trainer.
This class defines the configuration parameters for creating a PPO trainer,
including both required and optional fields with sensible defaults.
"""
collector: Any
    total_frames: int | None
optim_steps_per_batch: int | None
loss_module: Any
optimizer: Any
logger: Any
save_trainer_file: Any
replay_buffer: Any
frame_skip: int = 1
clip_grad_norm: bool = True
clip_norm: float | None = None
progress_bar: bool = True
seed: int | None = None
save_trainer_interval: int = 10000
log_interval: int = 10000
create_env_fn: Any = None
actor_network: Any = None
critic_network: Any = None
num_epochs: int = 4
async_collection: bool = False
_target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"
def __post_init__(self) -> None:
"""Post-initialization hook for PPO trainer configuration."""
super().__post_init__()
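

# As with ``SACTrainerConfig`` above, a ``PPOTrainerConfig`` is normally
# resolved with ``hydra.utils.instantiate``, which dispatches to
# ``_make_ppo_trainer`` through ``_target_``. The PPO-specific knob is
# ``num_epochs``, the number of optimization epochs run over each collected
# batch. Sketch (``ppo_cfg`` stands for a fully populated config):
#
#     from hydra.utils import instantiate
#
#     trainer = instantiate(ppo_cfg)  # ppo_cfg: PPOTrainerConfig
#     trainer.train()
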
def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
from torchrl.trainers.trainers import Logger
collector = kwargs.pop("collector")
total_frames = kwargs.pop("total_frames")
if total_frames is None:
total_frames = collector.total_frames
frame_skip = kwargs.pop("frame_skip", 1)
optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
loss_module = kwargs.pop("loss_module")
optimizer = kwargs.pop("optimizer")
logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm", None)
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed", None)
    actor_network = kwargs.pop("actor_network", None)
    critic_network = kwargs.pop("critic_network", None)
    # create_env_fn may be referenced elsewhere in the config, so it is
    # discarded here rather than treated as an error
    kwargs.pop("create_env_fn", None)
    num_epochs = kwargs.pop("num_epochs", 4)
    async_collection = kwargs.pop("async_collection", False)
    # Instantiate the networks first: when provided through the config they
    # are partials that build the actual modules once called
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()
    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if not async_collection:
            collector = collector()
        elif replay_buffer is not None:
            # async collectors write directly to the replay buffer
            collector = collector(replay_buffer=replay_buffer)
        else:
            raise ValueError(
                "replay_buffer must be provided when async_collection is True"
            )
    elif async_collection and getattr(collector, "replay_buffer", None) is None:
        raise ValueError(
            "replay_buffer must be provided when async_collection is True"
        )
if not isinstance(loss_module, LossModule):
# then it's a partial config
loss_module = loss_module(
actor_network=actor_network, critic_network=critic_network
)
if not isinstance(optimizer, torch.optim.Optimizer):
# then it's a partial config
optimizer = optimizer(params=loss_module.parameters())
# Quick instance checks
if not isinstance(collector, DataCollectorBase):
raise ValueError(
f"collector must be a DataCollectorBase, got {type(collector)}"
)
if not isinstance(loss_module, LossModule):
raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")
if not isinstance(optimizer, torch.optim.Optimizer):
raise ValueError(
f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
)
    if logger is not None and not isinstance(logger, Logger):
raise ValueError(f"logger must be a Logger, got {type(logger)}")
return PPOTrainer(
collector=collector,
total_frames=total_frames,
frame_skip=frame_skip,
optim_steps_per_batch=optim_steps_per_batch,
loss_module=loss_module,
optimizer=optimizer,
logger=logger,
clip_grad_norm=clip_grad_norm,
clip_norm=clip_norm,
progress_bar=progress_bar,
seed=seed,
save_trainer_interval=save_trainer_interval,
log_interval=log_interval,
save_trainer_file=save_trainer_file,
replay_buffer=replay_buffer,
num_epochs=num_epochs,
async_collection=async_collection,
)
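

# Note on async collection (sketch of the contract assumed by both factories):
# with ``async_collection=True`` the collector partial must accept the replay
# buffer at construction time, e.g.
#
#     collector = collector_partial(replay_buffer=replay_buffer)
#
# so that collected frames are written to the buffer in the background while
# the trainer samples from it concurrently.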