Rate this Page

CQLTrainer#

class torchrl.trainers.algorithms.CQLTrainer(*args, **kwargs)[source]#

A trainer class for Conservative Q-Learning (CQL) algorithm.

See also CQLTrainerConfig for the Hydra configuration counterpart.

This trainer implements the CQL algorithm, an off-policy actor-critic method that adds a conservative penalty to Q-value estimates to prevent overestimation of out-of-distribution actions. CQL is particularly effective for offline RL.

The trainer handles: - Replay buffer management for off-policy learning - Target network updates (SoftUpdate) for stable training - Policy weight updates to the data collector - Action spec passthrough from collector environment to loss module - Comprehensive logging of training metrics

Parameters:
  • collector (BaseCollector) – The data collector used to gather environment interactions.

  • total_frames (int) – Total number of frames to collect during training.

  • frame_skip (int) – Number of frames to skip between policy updates.

  • optim_steps_per_batch (int) – Number of optimization steps per collected batch.

  • loss_module (LossModule | Callable) – The CQL loss module.

  • optimizer (optim.Optimizer, optional) – The optimizer for training.

  • logger (Logger, optional) – Logger for recording training metrics. Defaults to None.

  • clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.

  • clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.

  • progress_bar (bool, optional) – Whether to show a progress bar. Defaults to True.

  • seed (int, optional) – Random seed for reproducibility. Defaults to None.

  • save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.

  • log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.

  • save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state.

  • replay_buffer (ReplayBuffer, optional) – Replay buffer for storing experiences. Defaults to None.

  • enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.

  • log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.

  • log_actions (bool, optional) – Whether to log action statistics. Defaults to True.

  • log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.

  • target_net_updater (TargetNetUpdater, optional) – Target network updater (typically SoftUpdate).

  • async_collection (bool, optional) – Whether to use async data collection. Defaults to False.

  • log_timings (bool, optional) – Whether to log timing information for hooks. Defaults to False.

Note

This is an experimental/prototype feature. The API may change in future versions. CQL is designed for offline RL but also works online. The conservative penalty prevents overestimation of Q-values for unseen actions.

load_from_file(file: str | Path, **kwargs) Trainer#

Loads a file and its state-dict in the trainer.

Keyword arguments are passed to the load() function. They are ignored when CKPT_BACKEND=memmap.

Note

When CKPT_BACKEND=torch, weights_only=True is set by default for safer deserialization. Pass weights_only=False explicitly only if you have custom (non-stdlib) objects in your state dict.

request_stop(reason: str | None = None) None#

Signal that training should stop at the next loop boundary.