SACTrainer

class torchrl.trainers.algorithms.SACTrainer(*args, **kwargs)[source]

A trainer class for the Soft Actor-Critic (SAC) algorithm.

This trainer implements SAC, an off-policy actor-critic algorithm that optimizes a stochastic policy, forming a bridge between stochastic policy optimization and DDPG-style approaches. SAC adds the policy's entropy to the reward to encourage exploration.
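
Concretely, SAC maximizes an entropy-augmented return (standard SAC formulation; $\alpha$ is the entropy temperature, $\mathcal{H}$ the entropy, and $\rho_\pi$ the state-action distribution induced by the policy):

$$J(\pi) = \sum_t \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}\left[ r(s_t, a_t) + \alpha\, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]$$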

The trainer handles:

  • Replay buffer management for off-policy learning

  • Target network updates with configurable update frequency

  • Policy weight updates to the data collector

  • Comprehensive logging of training metrics

  • Gradient clipping and optimization steps
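
Schematically, one iteration of the training loop looks as follows. This is a simplified sketch expressed in terms of the constructor arguments, not the actual implementation; hooks, logging, and edge cases are omitted:

import torch

for batch in collector:                          # gather environment interactions
    replay_buffer.extend(batch.reshape(-1))      # store new transitions
    for _ in range(optim_steps_per_batch):
        sample = replay_buffer.sample(batch_size)
        loss_td = loss_module(sample)            # actor, critic and alpha losses
        loss = sum(v for k, v in loss_td.items() if k.startswith("loss_"))
        loss.backward()
        if clip_grad_norm:
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), clip_norm)
        optimizer.step()
        optimizer.zero_grad()
        target_net_updater.step()                # soft-update target networks
    collector.update_policy_weights_()           # push fresh weights to the collector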

Parameters:
  • collector (DataCollectorBase) – The data collector used to gather environment interactions.

  • total_frames (int) – Total number of frames to collect during training.

  • frame_skip (int) – Number of frames to skip between policy updates.

  • optim_steps_per_batch (int) – Number of optimization steps per collected batch.

  • loss_module (LossModule | Callable) – The SAC loss module or a callable that computes losses.

  • optimizer (optim.Optimizer, optional) – The optimizer for training. If None, must be configured elsewhere.

  • logger (Logger, optional) – Logger for recording training metrics. Defaults to None.

  • clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.

  • clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.

  • progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.

  • seed (int, optional) – Random seed for reproducibility. Defaults to None.

  • save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.

  • log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.

  • save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.

  • replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.

  • batch_size (int, optional) – Batch size for sampling from replay buffer. Defaults to None.

  • enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.

  • log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.

  • log_actions (bool, optional) – Whether to log action statistics. Defaults to True.

  • log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.

  • target_net_updater (TargetNetUpdater, optional) – Target network updater for soft updates. Defaults to None.
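
For example, a soft target-network updater can be built from the loss module and passed in explicitly (a sketch assuming torchrl.objectives.SoftUpdate):

>>> from torchrl.objectives import SoftUpdate
>>> target_net_updater = SoftUpdate(loss_module, eps=0.995)
>>> # then: SACTrainer(..., target_net_updater=target_net_updater)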

Example

>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.objectives import SACLoss
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torch import optim
>>>
>>> # Set up collector, loss and replay buffer
>>> # (env, policy, actor_network and qvalue_network are assumed to be
>>> #  defined; see the sketch after this example for one way to build them)
>>> collector = SyncDataCollector(env, policy, frames_per_batch=1000)
>>> loss_module = SACLoss(actor_network, qvalue_network)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4)
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
>>>
>>> # Create and run trainer
>>> trainer = SACTrainer(
...     collector=collector,
...     total_frames=1000000,
...     frame_skip=1,
...     optim_steps_per_batch=100,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
... )
>>> trainer.train()
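
The names env, policy, actor_network and qvalue_network are assumed above. A minimal sketch of one way to define them, using a Gym Pendulum environment purely for illustration (the module choices are assumptions, not requirements of SACTrainer):

>>> import torch.nn as nn
>>> from tensordict.nn import NormalParamExtractor, TensorDictModule
>>> from torchrl.envs import GymEnv
>>> from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
>>>
>>> env = GymEnv("Pendulum-v1")
>>> n_obs = env.observation_spec["observation"].shape[-1]
>>> n_act = env.action_spec.shape[-1]
>>>
>>> # Actor maps observations to the loc/scale of a TanhNormal action distribution
>>> actor_module = TensorDictModule(
...     nn.Sequential(
...         MLP(in_features=n_obs, out_features=2 * n_act, num_cells=[256, 256]),
...         NormalParamExtractor(),
...     ),
...     in_keys=["observation"],
...     out_keys=["loc", "scale"],
... )
>>> actor_network = ProbabilisticActor(
...     actor_module,
...     in_keys=["loc", "scale"],
...     distribution_class=TanhNormal,
...     spec=env.action_spec,
... )
>>> policy = actor_network  # the collector rolls out the stochastic actor
>>>
>>> # Q-value network scores (observation, action) pairs
>>> qvalue_network = ValueOperator(
...     MLP(in_features=n_obs + n_act, out_features=1, num_cells=[256, 256]),
...     in_keys=["observation", "action"],
... )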

Note

This is an experimental/prototype feature; the API may change in future versions. Thanks to its entropy regularization, SAC is particularly effective for continuous control tasks and for environments where exploration is crucial.

load_from_file(file: str | pathlib.Path, **kwargs) → Trainer

Loads a file and its state-dict into the trainer.

Keyword arguments are forwarded to torch.load().
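
For example, to restore a previously saved trainer state (the checkpoint path is illustrative):

>>> trainer.load_from_file("sac_checkpoint.pt")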
