DQNTrainer¶
- class torchrl.trainers.algorithms.DQNTrainer(*args, **kwargs)[source]¶
A trainer class for the Deep Q-Network (DQN) algorithm.
This trainer implements the DQN algorithm, a value-based method for discrete action spaces that learns a Q-function and derives a greedy policy from it.
The trainer handles:

- Replay buffer management for off-policy learning
- Target network updates (typically HardUpdate) with a configurable update frequency
- Policy weight updates to the data collector
- Comprehensive logging of training metrics
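To illustrate the hard target-update pattern mentioned above, here is a minimal pure-Python sketch (not TorchRL's actual implementation; the parameter dictionaries and `update_interval` name are hypothetical stand-ins for network state):

```python
def hard_update(online_params, target_params, step, update_interval):
    """Copy online-network parameters into the target network every
    `update_interval` optimization steps (the HardUpdate pattern)."""
    if step % update_interval == 0:
        for name, value in online_params.items():
            target_params[name] = value
    return target_params

# Toy usage: plain dicts stand in for network parameters.
online = {"w": 1.0}
target = {"w": 0.0}
for step in range(1, 101):
    online["w"] += 0.01  # stand-in for a gradient update
    hard_update(online, target, step, update_interval=50)
```

Between syncs the target network stays frozen, which stabilizes the bootstrapped Q-targets; TorchRL's `HardUpdate` plays this role for the trainer.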
- Parameters:
collector (BaseCollector) – The data collector used to gather environment interactions.
total_frames (int) – Total number of frames to collect during training.
frame_skip (int) – Number of frames to skip between policy updates.
optim_steps_per_batch (int) – Number of optimization steps per collected batch.
loss_module (LossModule | Callable) – The DQN loss module or a callable that computes losses.
optimizer (optim.Optimizer, optional) – The optimizer for training.
logger (Logger, optional) – Logger for recording training metrics. Defaults to None.
clip_grad_norm (bool, optional) – Whether to clip gradient norms. Defaults to True.
clip_norm (float, optional) – Maximum gradient norm for clipping. Defaults to None.
progress_bar (bool, optional) – Whether to show a progress bar during training. Defaults to True.
seed (int, optional) – Random seed for reproducibility. Defaults to None.
save_trainer_interval (int, optional) – Interval for saving trainer state. Defaults to 10000.
log_interval (int, optional) – Interval for logging metrics. Defaults to 10000.
save_trainer_file (str | pathlib.Path, optional) – File path for saving trainer state. Defaults to None.
replay_buffer (ReplayBuffer, optional) – Replay buffer for storing and sampling experiences. Defaults to None.
enable_logging (bool, optional) – Whether to enable metric logging. Defaults to True.
log_rewards (bool, optional) – Whether to log reward statistics. Defaults to True.
log_observations (bool, optional) – Whether to log observation statistics. Defaults to False.
target_net_updater (TargetNetUpdater, optional) – Target network updater (typically HardUpdate). Defaults to None.
greedy_module (EGreedyModule, optional) – Epsilon-greedy exploration module. When provided, the module’s epsilon is annealed during training. Defaults to None.
async_collection (bool, optional) – Whether to use async data collection. Defaults to False.
log_timings (bool, optional) – Whether to log timing information for hooks. Defaults to False.
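To make the `greedy_module` annealing concrete, the sketch below shows linear epsilon annealing of the kind an epsilon-greedy exploration module performs during training (plain Python; the names `eps_init`, `eps_end`, and `annealing_num_steps` mirror common epsilon-greedy parameters, but the functions themselves are illustrative, not TorchRL's API):

```python
import random

def linear_epsilon(step, eps_init=1.0, eps_end=0.05, annealing_num_steps=10_000):
    """Linearly anneal epsilon from eps_init to eps_end over annealing_num_steps,
    then hold it at eps_end."""
    frac = min(step / annealing_num_steps, 1.0)
    return eps_init + frac * (eps_end - eps_init)

def epsilon_greedy_action(q_values, epsilon, rng=random):
    """With probability epsilon pick a uniformly random action,
    otherwise pick the greedy (argmax-Q) action."""
    if rng.random() < epsilon:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])
```

Early in training the policy explores almost uniformly; as epsilon decays, action selection converges toward the greedy policy derived from the learned Q-function.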
Example
>>> from torchrl.collectors import Collector
>>> from torchrl.objectives import DQNLoss
>>> from torchrl.data import ReplayBuffer, LazyTensorStorage
>>> from torchrl.objectives.utils import HardUpdate
>>> from torch import optim
>>>
>>> # Set up collector, loss, and replay buffer
>>> collector = Collector(env, policy, frames_per_batch=128)
>>> loss_module = DQNLoss(value_network, delay_value=True)
>>> optimizer = optim.Adam(loss_module.parameters(), lr=2.5e-4)
>>> replay_buffer = ReplayBuffer(storage=LazyTensorStorage(100000))
>>> target_net_updater = HardUpdate(loss_module, value_network_update_interval=50)
>>>
>>> trainer = DQNTrainer(
...     collector=collector,
...     total_frames=500000,
...     frame_skip=1,
...     optim_steps_per_batch=10,
...     loss_module=loss_module,
...     optimizer=optimizer,
...     replay_buffer=replay_buffer,
...     target_net_updater=target_net_updater,
... )
>>> trainer.train()
Note
This is an experimental/prototype feature; the API may change in future versions. DQN is designed for discrete action spaces (e.g., CartPole, Atari). For continuous control, consider using SACTrainer or DDPGTrainer instead.