SACTrainer
- class torchrl.trainers.algorithms.SACTrainer(*args, **kwargs)
A trainer class for the Soft Actor-Critic (SAC) algorithm.
This trainer implements SAC, an off-policy actor-critic method that learns a stochastic policy. It bridges stochastic policy optimization and DDPG-style approaches. SAC adds the policy's entropy to the reward (a maximum-entropy objective), which encourages exploration.
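The entropy term enters SAC's critic updates through the "soft" Bellman target. The sketch below computes that target for a single transition on plain floats. It assumes the common formulation with clipped double-Q (two target critics) and a fixed temperature `alpha`. It is not torchrl's `SACLoss` code, which operates on batched tensors. SAC variants may also tune `alpha` automatically rather than fixing it.

```python
def soft_q_target(
    reward: float,
    done: bool,
    next_q1: float,
    next_q2: float,
    next_log_prob: float,
    alpha: float,
    gamma: float = 0.99,
) -> float:
    """Soft (entropy-regularized) TD target for SAC's Q-networks.

    y = r + gamma * (1 - done) * (min(Q1'(s', a'), Q2'(s', a')) - alpha * log pi(a' | s'))
    where a' is sampled from the current policy at the next state s'.
    """
    # Clipped double-Q: use the smaller of the two target critics to curb overestimation.
    min_next_q = min(next_q1, next_q2)
    # Subtracting alpha * log pi rewards high-entropy (low log-probability) actions.
    soft_next_value = min_next_q - alpha * next_log_prob
    # No bootstrapping past a terminal transition.
    return reward + gamma * (1.0 - float(done)) * soft_next_value


y = soft_q_target(reward=1.0, done=False, next_q1=10.0, next_q2=9.0,
                  next_log_prob=-1.5, alpha=0.2)
# y == 1.0 + 0.99 * (9.0 + 0.2 * 1.5) = 10.207, up to float rounding
```

A larger `alpha` weights the entropy bonus more heavily relative to the Q-value, which pushes the policy toward more exploratory behaviour.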
The trainer handles:
- Replay buffer management for off-policy learning
- Target network updates with a configurable update frequency
- Pushing updated policy weights to the data collector
- Logging of training metrics (rewards, actions, and optionally observations)
- Gradient clipping and optimization steps
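These responsibilities fit together roughly as in the loop below. This is a framework-free sketch, not how torchrl structures its trainer. The four callables are hypothetical stand-ins for the collector, the loss-plus-optimizer step, the target network updater, and the collector's weight update. The `deque` stands in for a `ReplayBuffer`. The parameter names `target_update_interval` and `buffer_capacity` are illustrative only.

```python
import random
from collections import deque
from typing import Callable, Dict, List, Sequence


def run_off_policy_loop(
    collect_batch: Callable[[int], List[Dict]],
    optim_step: Callable[[Sequence[Dict]], float],
    update_target: Callable[[], None],
    sync_policy_weights: Callable[[], None],
    total_frames: int,
    frames_per_batch: int,
    optim_steps_per_batch: int,
    batch_size: int,
    target_update_interval: int = 1,
    buffer_capacity: int = 100_000,
    seed: int = 0,
) -> List[float]:
    """Framework-free sketch of an off-policy (SAC-style) training loop.

    The four callables are hypothetical stand-ins:
      collect_batch(n)      -> n new transitions from the environment
      optim_step(samples)   -> loss after backward, gradient clipping and optimizer step
      update_target()       -> soft update of the target Q-networks
      sync_policy_weights() -> push the updated policy weights to the collector
    """
    rng = random.Random(seed)
    replay_buffer = deque(maxlen=buffer_capacity)  # FIFO: oldest transitions are evicted first
    losses: List[float] = []
    collected = 0
    n_updates = 0
    while collected < total_frames:
        batch = collect_batch(frames_per_batch)
        replay_buffer.extend(batch)
        collected += len(batch)
        for _ in range(optim_steps_per_batch):
            # Uniform sampling with replacement over everything stored so far.
            samples = rng.choices(replay_buffer, k=batch_size)
            losses.append(optim_step(samples))
            n_updates += 1
            # Target networks are refreshed every `target_update_interval` gradient steps.
            if n_updates % target_update_interval == 0:
                update_target()
        # The collector gathers the next batch with the freshly updated policy.
        sync_policy_weights()
    return losses


# Usage with trivial stubs, only to exercise the control flow.
calls = {"target": 0, "sync": 0}


def fake_collect(n):
    return [{"obs": i} for i in range(n)]


def fake_optim_step(samples):
    return float(len(samples))  # pretend "loss"


def fake_update_target():
    calls["target"] += 1


def fake_sync():
    calls["sync"] += 1


losses = run_off_policy_loop(
    fake_collect, fake_optim_step, fake_update_target, fake_sync,
    total_frames=3000, frames_per_batch=1000,
    optim_steps_per_batch=5, batch_size=32, target_update_interval=5,
)
```

With these settings the loop collects three batches and runs 15 gradient steps. The target networks are updated three times and the policy weights are synced three times.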
- Parameters:
collector (DataCollectorBase) – The data collector used to gather environment interactions.
total_frames (int) – Total number of frames to collect during training.
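frame_skip (int) – Frame skip used in the environment. The trainer needs it to count frames correctly, because each collected step then spans frame_skip environment frames; otherwise the frame count would be underestimated.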
optim_steps_per_batch (int) – Number of optimization steps per collected batch.
loss_module (LossModule | Callable) – The SAC loss module or a callable that computes losses.
optimizer (optim.Optimizer, optional) – The optimizer used for training. If None, it must be configured elsewhere.
logger (Logger, optional) – Logger for recording training metrics. Defaults to None.
clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.
clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.
progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.
seed (int, optional) – Random seed for reproducibility. Defaults to None.
save_trainer_interval (int, optional) – How often, in collected frames, the trainer state is saved. Defaults to 10000.
log_interval (int, optional) – How often, in collected frames, metrics are logged. Defaults to 10000.
save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.
replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.
batch_size (int, optional) – Batch size for sampling from replay buffer. Defaults to None.
enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.
log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.
log_actions (bool, optional) – Whether to log action statistics. Defaults to True.
log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.
target_net_updater (TargetNetUpdater, optional) – Target network updater for soft updates. Defaults to None.
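clip_grad_norm and clip_norm control gradient clipping. The pure-Python sketch below contrasts the two standard strategies on a flat list of partial derivatives. Global-norm clipping rescales all gradients together, as `torch.nn.utils.clip_grad_norm_` does. Value clipping clamps each entry independently, as `torch.nn.utils.clip_grad_value_` does. We assume the base trainer uses the first when clip_grad_norm=True and the second otherwise, with clip_norm as the threshold. Treat that mapping as an assumption and check the trainer source before relying on it.

```python
import math
from typing import List


def clip_by_global_norm(grads: List[float], max_norm: float, eps: float = 1e-6) -> List[float]:
    """Rescale all gradients together so that their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + eps)  # eps avoids division by zero
    if scale < 1.0:
        return [g * scale for g in grads]
    return list(grads)  # already within budget: leave untouched


def clip_by_value(grads: List[float], clip_value: float) -> List[float]:
    """Clamp each partial derivative independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


grads = [3.0, 4.0]                         # global L2 norm = 5.0
by_norm = clip_by_global_norm(grads, 1.0)  # about [0.6, 0.8]: same direction, norm about 1
by_value = clip_by_value(grads, 1.0)       # [1.0, 1.0]: the direction changes
```

Norm clipping preserves the update direction and only shrinks its length. Value clipping is cheaper to reason about per parameter, but it can distort the direction.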
Example
>>> from torch import optim
>>> from torchrl.collectors import SyncDataCollector
>>> from torchrl.data import LazyTensorStorage, ReplayBuffer
>>> from torchrl.objectives import SACLoss
>>> from torchrl.trainers.algorithms import SACTrainer
>>>
>>> # env, policy, actor_network and qvalue_network are assumed to be defined already
>>> collector = SyncDataCollector(env, policy, frames_per_batch=1000, total_frames=1_000_000)
>>> loss_module = SACLoss(actor_network, qvalue_network)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=3e-4)
>>> # batch_size sets the default number of transitions returned by replay_buffer.sample()
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100_000), batch_size=256)
>>>
>>> # Create and run the trainer
>>> trainer = SACTrainer(
...     collector=collector,
...     total_frames=1_000_000,
...     frame_skip=1,
...     optim_steps_per_batch=100,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
... )
>>> trainer.train()
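The example above does not pass target_net_updater. In SAC, the target Q-networks are usually tracked by Polyak averaging (a "soft" update). torchrl provides `torchrl.objectives.SoftUpdate` for this purpose, and it is a natural candidate for that argument. We have not verified what SACTrainer does when the argument is omitted. The pure-Python sketch below shows only the update rule on plain lists. `tau=0.005` is an illustrative value, not a torchrl default.

```python
from typing import List


def polyak_update(target: List[float], online: List[float], tau: float) -> List[float]:
    """One soft update: target <- (1 - tau) * target + tau * online, element-wise."""
    return [(1.0 - tau) * t + tau * o for t, o in zip(target, online)]


online_params = [1.0, 2.0]   # stand-in for the online Q-network weights
target_params = [0.0, 0.0]   # stand-in for the target Q-network weights
for _ in range(1000):
    target_params = polyak_update(target_params, online_params, tau=0.005)
# With frozen online weights, the target closes the gap geometrically:
# target = online * (1 - (1 - tau) ** steps), about 99.3% of the way after 1000 steps.
```

A small tau makes the targets move slowly, which stabilizes the bootstrapped Q-targets at the cost of slower propagation of new value estimates.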
Note
This is an experimental/prototype feature, and the API may change in future versions. Because of its entropy regularization, SAC is particularly well suited to continuous control tasks and to environments where exploration is crucial.