Shortcuts

Source code for torchrl.trainers.algorithms.configs.trainers

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import torch

from torchrl.collectors import DataCollectorBase
from torchrl.objectives.common import LossModule
from torchrl.objectives.utils import TargetNetUpdater
from torchrl.trainers.algorithms.configs.common import ConfigBase
from torchrl.trainers.algorithms.ppo import PPOTrainer
from torchrl.trainers.algorithms.sac import SACTrainer


@dataclass
class TrainerConfig(ConfigBase):
    """Common parent of every trainer configuration dataclass.

    It declares no fields of its own; concrete subclasses add the trainer
    parameters together with the ``_target_`` factory path.
    """

    def __post_init__(self) -> None:
        """Hook invoked by the dataclass-generated ``__init__``.

        Deliberately does nothing, so that subclasses can always chain to it
        with ``super().__post_init__()``.
        """
@dataclass
class SACTrainerConfig(TrainerConfig):
    """Configuration class for SAC (Soft Actor Critic) trainer.

    This class defines the configuration parameters for creating a SAC trainer,
    including both required and optional fields with sensible defaults.

    Component fields (collector, loss module, optimizer, networks,
    target-network updater) may hold either ready-made objects or partial
    configs (callables); ``_make_sac_trainer`` instantiates the partials.
    """

    # Required fields. Their order is part of the positional constructor
    # signature and must not change.
    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any
    # Optional fields.
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000
    create_env_fn: Any = None  # accepted by the schema; ignored by the factory
    actor_network: Any = None
    critic_network: Any = None
    target_net_updater: Any = None
    async_collection: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_sac_trainer"

    def __post_init__(self) -> None:
        """Post-initialization hook for SAC trainer configuration."""
        super().__post_init__()


def _make_sac_trainer(*args, **kwargs) -> SACTrainer:
    """Build a :class:`SACTrainer` from (possibly partial) config values.

    Keyword arguments mirror the fields of :class:`SACTrainerConfig`.
    Components that are not already instances are treated as partial configs
    and called here in dependency order:

    1. the networks;
    2. the collector;
    3. the loss module, which receives the networks;
    4. the target-net updater, which receives the loss module;
    5. the optimizer, which receives the loss-module parameters.

    Positional ``args`` are accepted for factory-call compatibility and ignored.

    Returns:
        The configured :class:`SACTrainer`.

    Raises:
        ValueError: if ``async_collection`` is True but the collector ends up
            without a replay buffer, or if a component does not have the
            expected type after instantiation.
    """
    from torchrl.trainers.trainers import Logger

    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm", None)
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed", None)
    actor_network = kwargs.pop("actor_network", None)
    critic_network = kwargs.pop("critic_network", None)
    # Part of the config schema, but not needed to build the trainer.
    kwargs.pop("create_env_fn", None)
    target_net_updater = kwargs.pop("target_net_updater", None)
    async_collection = kwargs.pop("async_collection", False)

    # Instantiate networks first: the loss-module partial consumes them.
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()

    # In async mode, the replay buffer is injected into a partial collector
    # when one is supplied. Otherwise the collector must already own a buffer.
    buffer_injected = False
    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if async_collection and replay_buffer is not None:
            collector = collector(replay_buffer=replay_buffer)
            buffer_injected = True
        else:
            collector = collector()
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    # getattr (not attribute access): a collector may not define replay_buffer.
    if (
        async_collection
        and not buffer_injected
        and getattr(collector, "replay_buffer", None) is None
    ):
        raise ValueError(
            "replay_buffer must be provided when async_collection is True"
        )

    # Resolve the frame budget only now: a partial config is a callable, not
    # a collector, so it cannot be asked for ``total_frames``.
    if total_frames is None:
        total_frames = collector.total_frames

    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")

    # The updater is optional (the config defaults it to None). When it is
    # given as a partial, it takes the loss module as input.
    if target_net_updater is not None and not isinstance(
        target_net_updater, TargetNetUpdater
    ):
        target_net_updater = target_net_updater(loss_module)

    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )

    if logger is not None and not isinstance(logger, Logger):
        raise ValueError(f"logger must be a Logger, got {type(logger)}")

    return SACTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        target_net_updater=target_net_updater,
        async_collection=async_collection,
    )
@dataclass
class PPOTrainerConfig(TrainerConfig):
    """Configuration dataclass for the PPO (Proximal Policy Optimization) trainer.

    The ``_target_`` path points at ``_make_ppo_trainer``. That factory turns
    partial-config fields into concrete objects and then builds a
    :class:`PPOTrainer`. The first eight fields are required; all others
    carry defaults.
    """

    # --- required (order is part of the positional constructor) ---
    collector: Any
    total_frames: int
    optim_steps_per_batch: int | None
    loss_module: Any
    optimizer: Any
    logger: Any
    save_trainer_file: Any
    replay_buffer: Any
    # --- optional ---
    frame_skip: int = 1
    clip_grad_norm: bool = True
    clip_norm: float | None = None
    progress_bar: bool = True
    seed: int | None = None
    save_trainer_interval: int = 10000
    log_interval: int = 10000
    create_env_fn: Any = None  # accepted by the schema; ignored by the factory
    actor_network: Any = None
    critic_network: Any = None
    num_epochs: int = 4
    async_collection: bool = False
    _target_: str = "torchrl.trainers.algorithms.configs.trainers._make_ppo_trainer"

    def __post_init__(self) -> None:
        """Chain to the parent hook; PPO adds no validation of its own."""
        super().__post_init__()
def _make_ppo_trainer(*args, **kwargs) -> PPOTrainer:
    """Build a :class:`PPOTrainer` from (possibly partial) config values.

    Keyword arguments mirror the fields of :class:`PPOTrainerConfig`.
    Components that are not already instances are treated as partial configs
    and called here in dependency order:

    1. the networks;
    2. the collector;
    3. the loss module, which receives the networks;
    4. the optimizer, which receives the loss-module parameters.

    Positional ``args`` are accepted for factory-call compatibility and ignored.

    Returns:
        The configured :class:`PPOTrainer`.

    Raises:
        RuntimeError: if ``async_collection`` is True but the collector ends
            up without a replay buffer.
        ValueError: if a component does not have the expected type after
            instantiation.
    """
    from torchrl.trainers.trainers import Logger

    collector = kwargs.pop("collector")
    total_frames = kwargs.pop("total_frames")
    frame_skip = kwargs.pop("frame_skip", 1)
    optim_steps_per_batch = kwargs.pop("optim_steps_per_batch", 1)
    loss_module = kwargs.pop("loss_module")
    optimizer = kwargs.pop("optimizer")
    logger = kwargs.pop("logger")
    clip_grad_norm = kwargs.pop("clip_grad_norm", True)
    clip_norm = kwargs.pop("clip_norm", None)
    progress_bar = kwargs.pop("progress_bar", True)
    replay_buffer = kwargs.pop("replay_buffer")
    save_trainer_interval = kwargs.pop("save_trainer_interval", 10000)
    log_interval = kwargs.pop("log_interval", 10000)
    save_trainer_file = kwargs.pop("save_trainer_file")
    seed = kwargs.pop("seed", None)
    actor_network = kwargs.pop("actor_network", None)
    critic_network = kwargs.pop("critic_network", None)
    # The env factory may be referenced elsewhere in the config; it is not
    # needed to build the trainer, so it is simply discarded.
    kwargs.pop("create_env_fn", None)
    num_epochs = kwargs.pop("num_epochs", 4)
    async_collection = kwargs.pop("async_collection", False)

    # Instantiate networks first: the loss-module partial consumes them.
    if actor_network is not None:
        actor_network = actor_network()
    if critic_network is not None:
        critic_network = critic_network()

    # In async mode, the replay buffer is injected into a partial collector
    # only when one is supplied, so a buffer already bound on the partial is
    # not overwritten with None. Otherwise the collector must own a buffer.
    buffer_injected = False
    if not isinstance(collector, DataCollectorBase):
        # then it's a partial config
        if async_collection and replay_buffer is not None:
            collector = collector(replay_buffer=replay_buffer)
            buffer_injected = True
        else:
            collector = collector()
    if not isinstance(collector, DataCollectorBase):
        raise ValueError(
            f"collector must be a DataCollectorBase, got {type(collector)}"
        )
    if (
        async_collection
        and not buffer_injected
        and getattr(collector, "replay_buffer", None) is None
    ):
        raise RuntimeError(
            "replay_buffer must be provided when async_collection is True"
        )

    # Resolve the frame budget only now: a partial config is a callable, not
    # a collector, so it cannot be asked for ``total_frames``.
    if total_frames is None:
        total_frames = collector.total_frames

    if not isinstance(loss_module, LossModule):
        # then it's a partial config
        loss_module = loss_module(
            actor_network=actor_network, critic_network=critic_network
        )
    if not isinstance(loss_module, LossModule):
        raise ValueError(f"loss_module must be a LossModule, got {type(loss_module)}")

    if not isinstance(optimizer, torch.optim.Optimizer):
        # then it's a partial config
        optimizer = optimizer(params=loss_module.parameters())
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise ValueError(
            f"optimizer must be a torch.optim.Optimizer, got {type(optimizer)}"
        )

    if logger is not None and not isinstance(logger, Logger):
        raise ValueError(f"logger must be a Logger, got {type(logger)}")

    return PPOTrainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=frame_skip,
        optim_steps_per_batch=optim_steps_per_batch,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        clip_grad_norm=clip_grad_norm,
        clip_norm=clip_norm,
        progress_bar=progress_bar,
        seed=seed,
        save_trainer_interval=save_trainer_interval,
        log_interval=log_interval,
        save_trainer_file=save_trainer_file,
        replay_buffer=replay_buffer,
        num_epochs=num_epochs,
        async_collection=async_collection,
    )

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources