DQNTrainer

class torchrl.trainers.algorithms.DQNTrainer(*args, **kwargs)[source]

A trainer class for the Deep Q-Network (DQN) algorithm.

See also DQNTrainerConfig for the Hydra configuration counterpart.

This trainer implements the DQN algorithm, a value-based method for discrete action spaces that learns a Q-function and derives a greedy policy from it.

The trainer handles:

  • Replay buffer management for off-policy learning

  • Target network updates (typically HardUpdate) with a configurable update frequency

  • Policy weight updates to the data collector

  • Comprehensive logging of training metrics
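As an illustration of the hard-update rule listed above, here is a minimal pure-Python sketch (not TorchRL's HardUpdate implementation; plain dicts stand in for network state dicts):

```python
def hard_update(online_params, target_params, step, interval):
    """Copy the online network's parameters into the target network
    every `interval` optimization steps; otherwise leave the target
    network frozen."""
    if step % interval == 0:
        target_params.update(online_params)
    return target_params
```

Conceptually, TorchRL's HardUpdate performs the same copy on the loss module's delayed parameters at the configured value_network_update_interval.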

Parameters:
  • collector (BaseCollector) – The data collector used to gather environment interactions.

  • total_frames (int) – Total number of frames to collect during training.

  • frame_skip (int) – Number of frames to skip between policy updates.

  • optim_steps_per_batch (int) – Number of optimization steps per collected batch.

  • loss_module (LossModule | Callable) – The DQN loss module or a callable that computes losses.

  • optimizer (optim.Optimizer, optional) – The optimizer for training.

  • logger (Logger, optional) – Logger for recording training metrics. Defaults to None.

  • clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.

  • clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.

  • progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.

  • seed (int, optional) – Random seed for reproducibility. Defaults to None.

  • save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.

  • log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.

  • save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.

  • replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.

  • enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.

  • log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.

  • log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.

  • target_net_updater (TargetNetUpdater, optional) – Target network updater (typically HardUpdate). Defaults to None.

  • greedy_module (EGreedyModule, optional) – Epsilon-greedy exploration module. When provided, the module’s epsilon is annealed during training. Defaults to None.

  • async_collection (bool, optional) – Whether to use async data collection. Defaults to False.

  • log_timings (bool, optional) – Whether to log timing information for hooks. Defaults to False.

  • mixing_strategy (str, optional) – Multi-agent mixing strategy. Accepted values are "qmix" and "vdn" for mixed-value training, "iql" for independent Q-learning, or None for standard DQN. Defaults to None.

  • done_key (NestedKey, optional) – Key for the done signal used by logging. Defaults to "done".

  • terminated_key (NestedKey, optional) – Key for the terminated signal. Defaults to "terminated".

  • reward_key (NestedKey, optional) – Source reward key used by logging and reward aggregation. Defaults to "reward".

  • episode_reward_key (NestedKey, optional) – Source episode reward key used by logging and reward aggregation. Defaults to "reward_sum".

  • aggregated_reward_key (NestedKey, optional) – Destination key for rewards averaged over the agent dimension when using QMIX or VDN. The source is reward_key. Set this to reward_key to overwrite the source reward in-place. Required when mixing_strategy is "qmix" or "vdn". Defaults to None.

  • aggregated_episode_reward_key (NestedKey, optional) – Destination key for episode rewards averaged over the agent dimension when using QMIX or VDN. The source is episode_reward_key. Set this to episode_reward_key to overwrite the source reward in-place. Required when mixing_strategy is "qmix" or "vdn". Defaults to None.

  • action_key (NestedKey, optional) – Key for actions used by the exploration module and policy specs. Defaults to "action".

  • observation_key (NestedKey, optional) – Key for observations used by logging. Defaults to "observation".
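The epsilon annealing driven by greedy_module can be pictured with a simple linear schedule (an illustrative sketch only; TorchRL's EGreedyModule performs the actual annealing internally as the trainer steps it during training):

```python
def linear_epsilon(step, eps_init=1.0, eps_end=0.05, annealing_steps=10_000):
    """Linearly anneal exploration epsilon from `eps_init` to `eps_end`
    over `annealing_steps` environment steps, then hold at `eps_end`."""
    frac = min(step / annealing_steps, 1.0)
    return eps_init + frac * (eps_end - eps_init)
```

Early in training the policy explores almost uniformly at random; by the end of the schedule it acts greedily with respect to the learned Q-function except for a small residual epsilon.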

Example

>>> from torchrl.collectors import Collector
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.objectives.utils import HardUpdate
>>> from torch import optim
>>>
>>> # Set up collector, loss, and replay buffer
>>> collector = Collector(env, policy, frames_per_batch=128)
>>> loss_module = DQNLoss(value_network, delay_value=True)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=2.5e-4)
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
>>> target_net_updater = HardUpdate(loss_module, value_network_update_interval=50)
>>>
>>> trainer = DQNTrainer(
...     collector=collector,
...     total_frames=500000,
...     frame_skip=1,
...     optim_steps_per_batch=10,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
...     target_net_updater=target_net_updater,
... )
>>> trainer.train()

Note

This is an experimental/prototype feature. The API may change in future versions. DQN is designed for discrete action spaces (e.g., CartPole, Atari). For continuous control, consider using SACTrainer or DDPGTrainer instead.

load_from_file(file: str | Path, **kwargs) → Trainer

Loads a checkpoint file and restores its state dict into the trainer.

Keyword arguments are passed to the load() function. They are ignored when CKPT_BACKEND=memmap.

Note

When CKPT_BACKEND=torch, weights_only=True is set by default for safer deserialization. Pass weights_only=False explicitly only if you have custom (non-stdlib) objects in your state dict.

request_stop(reason: str | None = None) → None

Signal that training should stop at the next loop boundary.
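request_stop is cooperative: the flag is checked at a loop boundary rather than interrupting an update mid-step. A minimal sketch of that pattern (illustrative only, not the trainer's internals):

```python
import threading


class StopFlag:
    """Cooperative stop signal: the loop polls the flag at each
    boundary instead of being interrupted mid-update."""

    def __init__(self):
        self._event = threading.Event()
        self.reason = None

    def request_stop(self, reason=None):
        self.reason = reason
        self._event.set()

    def should_stop(self):
        return self._event.is_set()


def train_loop(stop, max_iters=100):
    """Run up to `max_iters` iterations, exiting cleanly between
    iterations once a stop has been requested."""
    completed = 0
    for _ in range(max_iters):
        if stop.should_stop():
            break
        completed += 1  # stand-in for one collect/optimize iteration
    return completed
```

Because the check happens between iterations, an in-flight optimization step always completes before the loop exits, which keeps checkpoints and logged metrics consistent.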
