DQNTrainer
- class torchrl.trainers.algorithms.DQNTrainer(*args, **kwargs)
A trainer class for the Deep Q-Network (DQN) algorithm.
See also
DQNTrainerConfig for the Hydra configuration counterpart.

This trainer implements the DQN algorithm, a value-based method for discrete action spaces that learns a Q-function and derives a greedy policy from it.
The trainer handles:
- Replay buffer management for off-policy learning
- Target network updates (typically HardUpdate) with configurable update frequency
- Policy weight updates to the data collector
- Comprehensive logging of training metrics
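The hard target update the trainer performs can be sketched in plain Python. This is an illustration only, not torchrl's implementation: plain dicts stand in for the online and target networks' parameter tensors, and the interval mirrors HardUpdate's value_network_update_interval.

```python
def hard_update(online: dict, target: dict) -> None:
    """Copy the online parameters into the target network in place."""
    for name, value in online.items():
        target[name] = value

online = {"w": 0.0}
target = {"w": 0.0}
update_interval = 50  # cf. HardUpdate(value_network_update_interval=50)

for step in range(1, 101):
    online["w"] += 0.01  # stand-in for one optimization step
    if step % update_interval == 0:
        # every `update_interval` steps, the target snaps to the online weights
        hard_update(online, target)
# after step 100 both dicts hold the same value
```

Between updates the target network stays frozen, which is what stabilizes the bootstrapped Q-targets in DQN.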
- Parameters:
collector (BaseCollector) – The data collector used to gather environment interactions.
total_frames (int) – Total number of frames to collect during training.
frame_skip (int) – Number of frames to skip between policy updates.
optim_steps_per_batch (int) – Number of optimization steps per collected batch.
loss_module (LossModule | Callable) – The DQN loss module or a callable that computes losses.
optimizer (optim.Optimizer, optional) – The optimizer for training.
logger (Logger, optional) – Logger for recording training metrics. Defaults to None.
clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.
clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.
progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.
seed (int, optional) – Random seed for reproducibility. Defaults to None.
save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.
log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.
save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.
replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.
enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.
log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.
log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.
target_net_updater (TargetNetUpdater, optional) – Target network updater (typically HardUpdate). Defaults to None.
greedy_module (EGreedyModule, optional) – Epsilon-greedy exploration module. When provided, the module’s epsilon is annealed during training. Defaults to None.
async_collection (bool, optional) – Whether to use async data collection. Defaults to False.
log_timings (bool, optional) – Whether to log timing information for hooks. Defaults to False.
mixing_strategy (str, optional) – Multi-agent mixing strategy. Accepted values are "qmix" and "vdn" for mixed-value training, "iql" for independent Q-learning, or None for standard DQN. Defaults to None.
done_key (NestedKey, optional) – Key for the done signal used by logging. Defaults to "done".
terminated_key (NestedKey, optional) – Key for the terminated signal. Defaults to "terminated".
reward_key (NestedKey, optional) – Source reward key used by logging and reward aggregation. Defaults to "reward".
episode_reward_key (NestedKey, optional) – Source episode reward key used by logging and reward aggregation. Defaults to "reward_sum".
aggregated_reward_key (NestedKey, optional) – Destination key for rewards averaged over the agent dimension when using QMIX or VDN. The source is reward_key. Set this to reward_key to overwrite the source reward in-place. Required when mixing_strategy is "qmix" or "vdn". Defaults to None.
aggregated_episode_reward_key (NestedKey, optional) – Destination key for episode rewards averaged over the agent dimension when using QMIX or VDN. The source is episode_reward_key. Set this to episode_reward_key to overwrite the source reward in-place. Required when mixing_strategy is "qmix" or "vdn". Defaults to None.
action_key (NestedKey, optional) – Key for actions used by the exploration module and policy specs. Defaults to "action".
observation_key (NestedKey, optional) – Key for observations used by logging. Defaults to "observation".
Example
>>> from torchrl.collectors import Collector
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.objectives.utils import HardUpdate
>>> from torch import optim
>>>
>>> # Set up collector, loss, and replay buffer
>>> collector = Collector(env, policy, frames_per_batch=128)
>>> loss_module = DQNLoss(value_network, delay_value=True)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=2.5e-4)
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
>>> target_net_updater = HardUpdate(loss_module, value_network_update_interval=50)
>>>
>>> trainer = DQNTrainer(
...     collector=collector,
...     total_frames=500000,
...     frame_skip=1,
...     optim_steps_per_batch=10,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
...     target_net_updater=target_net_updater,
... )
>>> trainer.train()
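For QMIX or VDN, the aggregated reward keys described above hold rewards averaged over the agent dimension. A sketch of that aggregation in plain Python, with a list of per-agent rewards standing in for a tensor of shape [..., n_agents, 1] (illustration only):

```python
def aggregate_rewards(per_agent: list[float]) -> float:
    """Average per-agent rewards over the agent dimension."""
    return sum(per_agent) / len(per_agent)

# three agents receiving different rewards on the same step
print(aggregate_rewards([1.0, 0.0, 2.0]))  # -> 1.0
```

The result is written under aggregated_reward_key; setting that key equal to reward_key overwrites the per-agent rewards in place, as the parameter documentation notes.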
Note
This is an experimental/prototype feature. The API may change in future versions. DQN is designed for discrete action spaces (e.g., CartPole, Atari). For continuous control, consider using SACTrainer or DDPGTrainer instead.
- load_from_file(file: str | Path, **kwargs) → Trainer
Loads a file and its state-dict into the trainer.
Keyword arguments are passed to the load() function. They are ignored when CKPT_BACKEND=memmap.

Note

When CKPT_BACKEND=torch, weights_only=True is set by default for safer deserialization. Pass weights_only=False explicitly only if you have custom (non-stdlib) objects in your state dict.
- request_stop(reason: str | None = None) → None
Signal that training should stop at the next loop boundary.