IQLTrainer

class torchrl.trainers.algorithms.IQLTrainer(*args, **kwargs)

A trainer class for the Implicit Q-Learning (IQL) algorithm.

This trainer implements the IQL algorithm, an off-policy actor-critic method that uses expectile regression for value function learning. IQL avoids querying out-of-distribution actions by using an implicit approach to Q-learning.

The trainer handles:
  • Replay buffer management for off-policy learning

  • Target network updates (SoftUpdate on Q-networks)

  • Policy weight updates to the data collector

  • Comprehensive logging of training metrics
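The target network update mentioned above is a soft (Polyak) update: the target parameters move a small step toward the online parameters on each call. The following is a minimal, dependency-free sketch of that idea on plain lists of floats. The helper `soft_update` and the weight name `eps` are hypothetical, and which side `eps` multiplies is an assumption of this sketch, so check the SoftUpdate documentation for the convention it actually uses.

```python
def soft_update(target, source, eps):
    """Blend parameters: eps * target + (1 - eps) * source.

    Assumption for this sketch: `eps` is the weight kept on the old
    target value, so eps close to 1 means a slowly moving target.
    """
    return [eps * t + (1.0 - eps) * s for t, s in zip(target, source)]


online = [1.0, -2.0]   # stand-in for the online Q-network parameters
target = [0.0, 0.0]    # stand-in for the target Q-network parameters

# Three soft updates with eps=0.5: the gap to the online values halves each time.
for _ in range(3):
    target = soft_update(target, online, eps=0.5)
```

With `eps=0.5` the target closes half of the remaining gap per update; in practice a value much closer to 1 is typical, so the target network changes slowly.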

IQL uses three networks: actor, Q-value, and value networks. The value network is trained with expectile regression, which provides an implicit way of extracting the maximum Q-value without explicit maximization.
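To illustrate why expectile regression approximates a maximum, here is a minimal, dependency-free sketch. It fits a single scalar `v` to a handful of Q-value samples under the asymmetric squared loss `|tau - 1(diff < 0)| * diff**2`, where `diff = q - v`. The function names, the sample values, and the gradient-descent settings are illustrative assumptions and are not part of the trainer or the loss module.

```python
def expectile_loss(diff, tau):
    """Asymmetric squared error: weight tau when diff > 0, else (1 - tau)."""
    weight = tau if diff > 0 else 1.0 - tau
    return weight * diff * diff


def fit_expectile(samples, tau, lr=0.1, steps=2000):
    """Find the scalar v minimizing the mean expectile loss by gradient descent."""
    v = sum(samples) / len(samples)  # start from the ordinary mean
    for _ in range(steps):
        grad = 0.0
        for q in samples:
            diff = q - v
            weight = tau if diff > 0 else 1.0 - tau
            grad += -2.0 * weight * diff  # d/dv of weight * (q - v)**2
        v -= lr * grad / len(samples)
    return v


q_samples = [0.0, 1.0, 2.0, 10.0]            # hypothetical Q-values for one state
v_mean = fit_expectile(q_samples, tau=0.5)   # symmetric case recovers the mean
v_high = fit_expectile(q_samples, tau=0.9)   # pulled toward the largest sample
```

With `tau=0.5` the loss is symmetric and the fit is the plain mean of the samples. Raising `tau` penalizes underestimates more heavily than overestimates, so the fit moves toward the largest Q-value without ever taking an explicit max over actions.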

Parameters:
  • collector (BaseCollector) – The data collector used to gather environment interactions.

  • total_frames (int) – Total number of frames to collect during training.

  • frame_skip (int) – Number of frames to skip between policy updates.

  • optim_steps_per_batch (int) – Number of optimization steps per collected batch.

  • loss_module (LossModule | Callable) – The IQL loss module.

  • optimizer (optim.Optimizer, optional) – The optimizer for training.

  • logger (Logger, optional) – Logger for recording training metrics. Defaults to None.

  • clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.

  • clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.

  • progress_bar (bool, optional) – Whether to show a progress bar. Defaults to True.

  • seed (int, optional) – Random seed for reproducibility. Defaults to None.

  • save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.

  • log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.

  • save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state.

  • replay_buffer (ReplayBuffer, optional) – Replay buffer for storing experiences. Defaults to None.

  • enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.

  • log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.

  • log_actions (bool, optional) – Whether to log action statistics. Defaults to True.

  • log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.

  • target_net_updater (TargetNetUpdater, optional) – Target network updater (typically SoftUpdate).

  • async_collection (bool, optional) – Whether to use async data collection. Defaults to False.

  • log_timings (bool, optional) – Whether to log timing information for hooks. Defaults to False.

Note

This is an experimental/prototype feature. The API may change in future versions. IQL works well for both online and offline RL. For offline RL, configure the collector to use a pre-collected dataset.

load_from_file(file: str | Path, **kwargs) → Trainer

Loads a file and its state-dict into the trainer.

Keyword arguments are passed to the load() function.
