SACTrainer

class torchrl.trainers.algorithms.SACTrainer(*args, **kwargs)[source]

A trainer class for the Soft Actor-Critic (SAC) algorithm.

This trainer implements the SAC algorithm, an off-policy actor-critic method that optimizes a stochastic policy in an off-policy way, forming a bridge between stochastic policy optimization and DDPG-style approaches. SAC incorporates the entropy measure of the policy into the reward to encourage exploration.
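For reference, the objective SAC maximizes is the standard maximum-entropy (entropy-regularized) return of Haarnoja et al. (2018):

    J(\pi) = \sum_t \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[ r(s_t, a_t) + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]

where \alpha is the temperature coefficient trading off reward against the policy entropy \mathcal{H}.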

The trainer handles:

  • Replay buffer management for off-policy learning

  • Target network updates with configurable update frequency

  • Policy weight updates to the data collector

  • Comprehensive logging of training metrics

  • Gradient clipping and optimization steps

Parameters:
  • collector (BaseCollector) – The data collector used to gather environment interactions.

  • total_frames (int) – Total number of frames to collect during training.

  • frame_skip (int) – Number of frames to skip between policy updates.

  • optim_steps_per_batch (int) – Number of optimization steps per collected batch.

  • loss_module (LossModule | Callable) – The SAC loss module or a callable that computes losses.

  • optimizer (optim.Optimizer, optional) – The optimizer for training. If None, must be configured elsewhere.

  • logger (Logger, optional) – Logger for recording training metrics. Defaults to None.

  • clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.

  • clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.

  • progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.

  • seed (int, optional) – Random seed for reproducibility. Defaults to None.

  • save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.

  • log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.

  • save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.

  • replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.

  • batch_size (int, optional) – Batch size for sampling from replay buffer. Defaults to None.

  • enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.

  • log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.

  • log_actions (bool, optional) – Whether to log action statistics. Defaults to True.

  • log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.

  • target_net_updater (TargetNetUpdater, optional) – Target network updater for soft updates. Defaults to None.

Example

>>> from torchrl.collectors import Collector
>>> from torchrl.objectives import SACLoss
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torch import optim
>>>
>>> # Set up collector, loss, and replay buffer
>>> collector = Collector(env, policy, frames_per_batch=1000)
>>> loss_module = SACLoss(actor_network, qvalue_network)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4)
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
>>>
>>> # Create and run trainer
>>> trainer = SACTrainer(
...     collector=collector,
...     total_frames=1000000,
...     frame_skip=1,
...     optim_steps_per_batch=100,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
... )
>>> trainer.train()
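The target network update rule and the logging behavior can also be configured explicitly. The following is a minimal sketch (not part of the official example) that reuses the objects defined above and assumes SoftUpdate from torchrl.objectives and CSVLogger from torchrl.record.loggers; the exact keyword arguments (e.g. tau, exp_name) are assumptions and may differ in your version.

>>> from torchrl.objectives import SoftUpdate
>>> from torchrl.record.loggers import CSVLogger
>>>
>>> # Soft (Polyak) target-network updates and a simple CSV logger (assumed kwargs)
>>> target_net_updater = SoftUpdate(loss_module, tau=0.005)
>>> logger = CSVLogger(exp_name="sac_example")
>>>
>>> trainer = SACTrainer(
...     collector=collector,
...     total_frames=1_000_000,
...     frame_skip=1,
...     optim_steps_per_batch=100,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
...     target_net_updater=target_net_updater,
...     logger=logger,
...     log_rewards=True,
...     log_actions=True,
... )
>>> trainer.train()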

Note

This is an experimental/prototype feature. The API may change in future versions. SAC is particularly effective for continuous control tasks and environments where exploration is crucial due to its entropy regularization.

load_from_file(file: str | Path, **kwargs) → Trainer

Loads a file and its state dict into the trainer.

Keyword arguments are passed to the load() function.
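A hedged usage sketch: the checkpoint path below is hypothetical, and it assumes the trainer state was previously written out via save_trainer_file / save_trainer_interval.

>>> # Restore a previously saved trainer state (hypothetical checkpoint path)
>>> trainer = trainer.load_from_file("checkpoints/sac_trainer.pt")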
