.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/coding_dqn.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_coding_dqn.py:

TorchRL trainer: A DQN example
==============================

**Author**: `Vincent Moens `_

.. _coding_dqn:

.. GENERATED FROM PYTHON SOURCE LINES 11-85

TorchRL provides a generic :class:`~torchrl.trainers.Trainer` class to handle
your training loop. The trainer executes a nested loop where the outer loop
is the data collection and the inner loop consumes this data or some data
retrieved from the replay buffer to train the model.
At various points in this training loop, hooks can be attached and executed at
given intervals.

In this tutorial, we will be using the trainer class to train a DQN algorithm
to solve the CartPole task from scratch.

Main takeaways:

- Building a trainer with its essential components: data collector, loss
  module, replay buffer and optimizer.
- Adding hooks to a trainer, such as loggers and target network updaters.

The trainer is fully customisable and offers a large set of functionalities.
The tutorial is organised around its construction.
We will first detail how to build each of the components of the library,
and then put the pieces together using the :class:`~torchrl.trainers.Trainer`
class.

Along the way, we will also focus on some other aspects of the library:

- how to build an environment in TorchRL, including transforms (e.g. data
  normalization, frame concatenation, resizing and turning to grayscale)
  and parallel execution. Unlike what we did in the
  :ref:`DDPG tutorial `, we will
  normalize the pixels and not the state vector.
- how to design a :class:`~torchrl.modules.QValueActor` object, i.e.
  an actor that estimates the action values and picks the action with
  the highest estimated return;
- how to collect data from your environment efficiently and store them
  in a replay buffer;
- how to use multi-step, a simple preprocessing step for off-policy algorithms;
- and finally how to evaluate your model.

**Prerequisites**: We encourage you to get familiar with torchrl through the
:ref:`PPO tutorial ` first.

DQN
---

DQN (`Deep Q-Learning `_) was
the founding work in deep reinforcement learning.

At a high level, the algorithm is quite simple: Q-learning consists of
learning a table of state-action values in such a way that, when
encountering any particular state, we know which action to pick just by
searching for the one with the highest value. This simple setting
requires the actions and states to be
discrete, otherwise a lookup table cannot be built.

DQN uses a neural network that encodes a map from the state-action space to
a value (scalar) space, which amortizes the cost of storing and exploring all
the possible state-action combinations: if a state has not been seen in the
past, we can still pass it in conjunction with the various actions available
through our neural network and get an interpolated value for each of the
actions available.

We will solve the classic control problem of the cart pole. From the
Gymnasium documentation, where this environment comes from:

| A pole is attached by an un-actuated joint to a cart, which moves along a
| frictionless track. The pendulum is placed upright on the cart and the goal
| is to balance the pole by applying forces in the left and right direction
| on the cart.

.. figure:: /_static/img/cartpole_demo.gif
   :alt: Cart Pole

We do not aim at giving a SOTA implementation of the algorithm, but rather
to provide a high-level illustration of TorchRL features in the context
of this algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 85-136

..
.. code-block:: Python

    import multiprocessing
    import os
    import tempfile
    import uuid
    import warnings

    import torch
    from tensordict.nn import TensorDictSequential
    from torch import nn
    from torchrl.collectors import MultiaSyncDataCollector, SyncDataCollector
    from torchrl.data import LazyMemmapStorage, MultiStep, TensorDictReplayBuffer
    from torchrl.envs import (
        EnvCreator,
        ExplorationType,
        ParallelEnv,
        RewardScaling,
        StepCounter,
    )
    from torchrl.envs.libs.gym import GymEnv
    from torchrl.envs.transforms import (
        CatFrames,
        Compose,
        GrayScale,
        ObservationNorm,
        Resize,
        ToTensorImage,
        TransformedEnv,
    )
    from torchrl.modules import DuelingCnnDQNet, EGreedyModule, QValueActor

    from torchrl.objectives import DQNLoss, SoftUpdate
    from torchrl.record.loggers.csv import CSVLogger
    from torchrl.trainers import (
        LogReward,
        Recorder,
        ReplayBufferTrainer,
        Trainer,
        UpdateWeights,
    )


    def is_notebook() -> bool:
        try:
            shell = get_ipython().__class__.__name__
            if shell == "ZMQInteractiveShell":
                return True  # Jupyter notebook or qtconsole
            elif shell == "TerminalInteractiveShell":
                return False  # Terminal running IPython
            else:
                return False  # Other type (?)
        except NameError:
            return False  # Probably standard Python interpreter

.. GENERATED FROM PYTHON SOURCE LINES 165-217

Let's get started with the various pieces we need for our algorithm:

- An environment;
- A policy (and related modules that we group under the "model" umbrella);
- A data collector, which makes the policy play in the environment and
  delivers training data;
- A replay buffer to store the training data;
- A loss module, which computes the objective function to train our policy
  to maximise the return;
- An optimizer, which performs parameter updates based on our loss.

Additional modules include a logger, a recorder (executes the policy in
"eval" mode) and a target network updater. With all these components in
place, it is easy to see how one could misplace or misuse one component in
the training script. The trainer is there to orchestrate everything for you!
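Before building the real components, the nested loop that the trainer
orchestrates can be sketched in plain Python. This is only a minimal
illustration under stated assumptions, not TorchRL's implementation:
``TinyTrainer``, ``fake_collector`` and the ``"optim"`` hook location are
hypothetical names invented for this sketch, whereas ``"post_optim"`` and
``"post_steps"`` mirror the hook locations used later in this tutorial.

```python
import random
from collections import defaultdict


class TinyTrainer:
    """Hypothetical, framework-free stand-in for a trainer (not TorchRL's code)."""

    def __init__(self, collector, optim_steps_per_batch, total_frames, batch_size=4):
        self.collector = collector
        self.optim_steps_per_batch = optim_steps_per_batch
        self.total_frames = total_frames
        self.batch_size = batch_size
        self.buffer = []  # plays the role of the replay buffer
        self.hooks = defaultdict(list)  # hook location -> list of callables
        self.collected_frames = 0

    def register_op(self, dest, op):
        # Attach a callable to a named location of the loop.
        self.hooks[dest].append(op)

    def _run_hooks(self, dest, *args):
        for op in self.hooks[dest]:
            op(*args)

    def train(self):
        for batch in self.collector:  # outer loop: data collection
            self.buffer.extend(batch)
            self.collected_frames += len(batch)
            for _ in range(self.optim_steps_per_batch):  # inner loop: optimization
                k = min(self.batch_size, len(self.buffer))
                sample = random.sample(self.buffer, k)
                self._run_hooks("optim", sample)  # e.g. loss + optimizer step
                self._run_hooks("post_optim")  # e.g. target network update
            self._run_hooks("post_steps")  # e.g. epsilon annealing
            if self.collected_frames >= self.total_frames:
                break


def fake_collector(frames_per_batch):
    # Endless stream of batches of dummy "frames" (random floats).
    while True:
        yield [random.random() for _ in range(frames_per_batch)]


counts = defaultdict(int)


def count(name):
    # Build a hook that simply counts how often its location is reached.
    def hook(*args):
        counts[name] += 1

    return hook


trainer = TinyTrainer(fake_collector(8), optim_steps_per_batch=2, total_frames=32)
for name in ("optim", "post_optim", "post_steps"):
    trainer.register_op(name, count(name))
trainer.train()
print(dict(counts))
```

With 8 frames per batch and a budget of 32 frames, the outer loop runs four
times and the inner loop runs twice per batch, so the per-optimization hooks
fire eight times and the per-batch hook four times. The real trainer follows
the same pattern with more hook locations and actual tensors.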
Building the environment
------------------------

First let's write a helper function that will output an environment. As usual,
the "raw" environment may be too simple to be used in practice and we'll need
some data transformation to expose its output to the policy.

We will be using the following transforms:

- :class:`~torchrl.envs.StepCounter` to count the number of steps in each trajectory;
- :class:`~torchrl.envs.transforms.ToTensorImage` will convert a ``[W, H, C]`` uint8
  tensor into a floating point tensor in the ``[0, 1]`` space with shape
  ``[C, W, H]``;
- :class:`~torchrl.envs.transforms.RewardScaling` to reduce the scale of the return;
- :class:`~torchrl.envs.transforms.GrayScale` will turn our image into grayscale;
- :class:`~torchrl.envs.transforms.Resize` will resize the image to a 64x64 format;
- :class:`~torchrl.envs.transforms.CatFrames` will concatenate an arbitrary
  number of successive frames (``N=4``) in a single tensor along the channel
  dimension. This is useful as a single image does not carry information
  about the motion of the cartpole. Some memory about past observations and
  actions is needed, either via a recurrent neural network or using a stack
  of frames.
- :class:`~torchrl.envs.transforms.ObservationNorm` which will normalize our
  observations given some custom summary statistics.

In practice, our environment builder has three arguments:

- ``parallel``: determines whether multiple environments have to be run in
  parallel. We stack the transforms after the
  :class:`~torchrl.envs.ParallelEnv` to take advantage
  of vectorization of the operations on device, although this would
  technically work with every single environment attached to its own set of
  transforms.
- ``obs_norm_sd`` will contain the normalizing constants for
  the :class:`~torchrl.envs.ObservationNorm` transform.
- ``num_workers``: the number of environments run in parallel when
  ``parallel=True``.

.. GENERATED FROM PYTHON SOURCE LINES 217-266

..
.. code-block:: Python

    def make_env(
        parallel=False,
        obs_norm_sd=None,
        num_workers=1,
    ):
        if obs_norm_sd is None:
            obs_norm_sd = {"standard_normal": True}
        if parallel:

            def maker():
                return GymEnv(
                    "CartPole-v1",
                    from_pixels=True,
                    pixels_only=True,
                    device=device,
                )

            base_env = ParallelEnv(
                num_workers,
                EnvCreator(maker),
                # Don't create a sub-process if we have only one worker
                serial_for_single=True,
                mp_start_method=mp_context,
            )
        else:
            base_env = GymEnv(
                "CartPole-v1",
                from_pixels=True,
                pixels_only=True,
                device=device,
            )

        env = TransformedEnv(
            base_env,
            Compose(
                StepCounter(),  # to count the steps of each trajectory
                ToTensorImage(),
                RewardScaling(loc=0.0, scale=0.1),
                GrayScale(),
                Resize(64, 64),
                CatFrames(4, in_keys=["pixels"], dim=-3),
                ObservationNorm(in_keys=["pixels"], **obs_norm_sd),
            ),
        )
        return env

.. GENERATED FROM PYTHON SOURCE LINES 267-278

Compute normalizing constants
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To normalize images, we don't want to normalize each pixel independently
with a full ``[C, W, H]`` normalizing mask, but with a simpler ``[C, 1, 1]``
shaped set of normalizing constants (loc and scale parameters).
We will be using the ``reduce_dim`` argument
of :meth:`~torchrl.envs.ObservationNorm.init_stats` to specify which
dimensions must be reduced, and the ``keep_dims`` parameter to ensure that
not all dimensions disappear in the process:

.. GENERATED FROM PYTHON SOURCE LINES 278-294

.. code-block:: Python

    def get_norm_stats():
        test_env = make_env()
        test_env.transform[-1].init_stats(
            num_iter=1000, cat_dim=0, reduce_dim=[-1, -2, -4], keep_dims=(-1, -2)
        )
        obs_norm_sd = test_env.transform[-1].state_dict()
        # let's check that normalizing constants have a size of ``[C, 1, 1]`` where
        # ``C=4`` (because of :class:`~torchrl.envs.CatFrames`).
        print("state dict of the observation norm:", obs_norm_sd)
        test_env.close()
        del test_env
        return obs_norm_sd

..
.. GENERATED FROM PYTHON SOURCE LINES 295-317

Building the model (Deep Q-network)
-----------------------------------

The following function builds a :class:`~torchrl.modules.DuelingCnnDQNet`
object which is a simple CNN followed by a two-layer MLP. The only trick used
here is that the action values (i.e. left and right action value) are
computed using

.. math::

   \mathbb{v} = b(obs) + v(obs) - \mathbb{E}[v(obs)]

where :math:`\mathbb{v}` is our vector of action values,
:math:`b` is a :math:`\mathbb{R}^n \rightarrow \mathbb{R}` function and :math:`v` is a
:math:`\mathbb{R}^n \rightarrow \mathbb{R}^m` function, for
:math:`n = \# obs` and :math:`m = \# actions`.

Our network is wrapped in a :class:`~torchrl.modules.QValueActor`,
which will read the state-action
values, pick the one with the maximum value and write all those results
in the input :class:`tensordict.TensorDict`.

.. GENERATED FROM PYTHON SOURCE LINES 317-361

.. code-block:: Python

    def make_model(dummy_env):
        cnn_kwargs = {
            "num_cells": [32, 64, 64],
            "kernel_sizes": [6, 4, 3],
            "strides": [2, 2, 1],
            "activation_class": nn.ELU,
            # This can be used to reduce the size of the last layer of the CNN
            # "squeeze_output": True,
            # "aggregator_class": nn.AdaptiveAvgPool2d,
            # "aggregator_kwargs": {"output_size": (1, 1)},
        }
        mlp_kwargs = {
            "depth": 2,
            "num_cells": [
                64,
                64,
            ],
            "activation_class": nn.ELU,
        }
        net = DuelingCnnDQNet(
            dummy_env.action_spec.shape[-1], 1, cnn_kwargs, mlp_kwargs
        ).to(device)
        net.value[-1].bias.data.fill_(init_bias)

        actor = QValueActor(net, in_keys=["pixels"], spec=dummy_env.action_spec).to(device)
        # init actor: because the model is composed of lazy conv/linear layers,
        # we must pass a fake batch of data through it to instantiate them.
        tensordict = dummy_env.fake_tensordict()
        actor(tensordict)

        # we join our actor with an EGreedyModule for data collection
        exploration_module = EGreedyModule(
            spec=dummy_env.action_spec,
            annealing_num_steps=total_frames,
            eps_init=eps_greedy_val,
            eps_end=eps_greedy_val_env,
        )
        actor_explore = TensorDictSequential(actor, exploration_module)

        return actor, actor_explore

.. GENERATED FROM PYTHON SOURCE LINES 362-381

Collecting and storing data
---------------------------

Replay buffers
~~~~~~~~~~~~~~

Replay buffers play a central role in off-policy RL algorithms such as DQN.
They constitute the dataset we will be sampling from during training.

Here, we will use a regular sampling strategy, although a prioritized
replay buffer could improve the performance significantly.

We place the storage on disk using the
:class:`~torchrl.data.replay_buffers.storages.LazyMemmapStorage` class. This
storage is created in a lazy manner: it will only be instantiated once the
first batch of data is passed to it.

The only requirement of this storage is that the data passed to it at write
time must always have the same shape.

.. GENERATED FROM PYTHON SOURCE LINES 381-392

.. code-block:: Python

    def get_replay_buffer(buffer_size, n_optim, batch_size):
        replay_buffer = TensorDictReplayBuffer(
            batch_size=batch_size,
            storage=LazyMemmapStorage(buffer_size),
            prefetch=n_optim,
        )
        return replay_buffer

.. GENERATED FROM PYTHON SOURCE LINES 393-428

Data collector
~~~~~~~~~~~~~~

As in :ref:`PPO ` and
:ref:`DDPG `, we will be using
a data collector as a dataloader in the outer loop.

We choose the following configuration: each collector runs a batch of
parallel environments synchronously, while the collectors themselves run in
parallel but asynchronously.

.. note::
  This feature is only available when running the code within the "spawn"
  start method of the Python multiprocessing library.
  If this tutorial is run directly as a script (thereby using the "fork"
  method) we will be using a regular
  :class:`~torchrl.collectors.SyncDataCollector`.

The advantage of this configuration is that we can balance the amount of
compute that is executed in batch with what we want to be executed
asynchronously. We encourage the reader to experiment with how the collection
speed is impacted by modifying the number of collectors (i.e. the number of
environment constructors passed to the collector) and the number of
environments executed in parallel in each collector (controlled by the
``num_workers`` hyperparameter).

Collector devices are fully configurable through the ``device`` (general),
``policy_device``, ``env_device`` and ``storing_device`` arguments.
The ``storing_device`` argument will modify the
location of the data being collected: if the batches that we are gathering
have a considerable size, we may want to store them in a different location
than the device where the computation is happening. For asynchronous data
collectors such as ours, different storing devices mean that the data that
we collect won't sit on the same device each time, which is something that
our training loop must account for. For simplicity, we set the devices to
the same value for all sub-collectors.

.. GENERATED FROM PYTHON SOURCE LINES 428-464

..
.. code-block:: Python

    def get_collector(
        stats,
        num_collectors,
        actor_explore,
        frames_per_batch,
        total_frames,
        device,
    ):
        # We can't use nested child processes with mp_start_method="fork"
        if is_fork:
            cls = SyncDataCollector
            env_arg = make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
        else:
            cls = MultiaSyncDataCollector
            env_arg = [
                make_env(parallel=True, obs_norm_sd=stats, num_workers=num_workers)
            ] * num_collectors
        data_collector = cls(
            env_arg,
            policy=actor_explore,
            frames_per_batch=frames_per_batch,
            total_frames=total_frames,
            # this is the default behaviour: the collector runs in ``"random"`` (or explorative) mode
            exploration_type=ExplorationType.RANDOM,
            # We set all the devices to be identical.
            device=device,
            storing_device=device,
            split_trajs=False,
            postproc=MultiStep(gamma=gamma, n_steps=5),
        )
        return data_collector

.. GENERATED FROM PYTHON SOURCE LINES 465-482

Loss function
-------------

Building our loss function is straightforward: we only need to provide
the model and a few hyperparameters to the
:class:`~torchrl.objectives.DQNLoss` class.

Target parameters
~~~~~~~~~~~~~~~~~

Many off-policy RL algorithms use the concept of "target parameters" when it
comes to estimating the value of the next state or state-action pair.
The target parameters are lagged copies of the model parameters. Because
they change more slowly than the parameters being trained, the regression
target does not shift at every optimization step, which stabilizes learning.
This is a powerful trick that is ubiquitous in similar algorithms. (It is
distinct from "Double Q-Learning", which additionally decouples action
selection from action evaluation.)

.. GENERATED FROM PYTHON SOURCE LINES 482-491

.. code-block:: Python

    def get_loss_module(actor, gamma):
        loss_module = DQNLoss(actor, delay_value=True)
        loss_module.make_value_estimator(gamma=gamma)
        target_updater = SoftUpdate(loss_module, eps=0.995)
        return loss_module, target_updater

..
.. GENERATED FROM PYTHON SOURCE LINES 492-498

Hyperparameters
---------------

Let's start with our hyperparameters. The following settings should work well
in practice, and the performance of the algorithm should hopefully not be
too sensitive to slight variations of these.

.. GENERATED FROM PYTHON SOURCE LINES 498-506

.. code-block:: Python

    is_fork = multiprocessing.get_start_method() == "fork"
    # ``make_env`` reads ``mp_context``, which is not defined in the code shown
    # here; as an assumption, we reuse the current multiprocessing start method.
    mp_context = multiprocessing.get_start_method()
    device = (
        torch.device(0)
        if torch.cuda.is_available() and not is_fork
        else torch.device("cpu")
    )

.. GENERATED FROM PYTHON SOURCE LINES 507-509

Optimizer
~~~~~~~~~

.. GENERATED FROM PYTHON SOURCE LINES 509-519

.. code-block:: Python

    # the learning rate of the optimizer
    lr = 2e-3
    # weight decay
    wd = 1e-5
    # the beta parameters of Adam
    betas = (0.9, 0.999)
    # Optimization steps per batch collected (aka UPD or updates per data)
    n_optim = 8

.. GENERATED FROM PYTHON SOURCE LINES 520-523

DQN parameters
~~~~~~~~~~~~~~

Discount factor (``gamma``):

.. GENERATED FROM PYTHON SOURCE LINES 523-525

.. code-block:: Python

    gamma = 0.99

.. GENERATED FROM PYTHON SOURCE LINES 526-529

Smooth target network update decay parameter.
This loosely corresponds to a 1/tau interval with hard target network
update.

.. GENERATED FROM PYTHON SOURCE LINES 529-531

.. code-block:: Python

    tau = 0.02

.. GENERATED FROM PYTHON SOURCE LINES 532-545

Data collection and replay buffer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. note::
  Values to be used for proper training are given in the comments.

Total frames collected in the environment. In other implementations, the
user defines a maximum number of episodes.
This is harder to do with our data collectors since they return batches
of N collected frames, where N is a constant.
However, one can easily get the same restriction on the number of episodes by
breaking the training loop when a certain number of episodes has been
collected.

.. GENERATED FROM PYTHON SOURCE LINES 545-547

.. code-block:: Python

    total_frames = 5_000  # 500000

..
.. GENERATED FROM PYTHON SOURCE LINES 548-549

Random frames used to initialize the replay buffer.

.. GENERATED FROM PYTHON SOURCE LINES 549-551

.. code-block:: Python

    init_random_frames = 100  # 1000

.. GENERATED FROM PYTHON SOURCE LINES 552-553

Frames in each batch collected.

.. GENERATED FROM PYTHON SOURCE LINES 553-555

.. code-block:: Python

    frames_per_batch = 32  # 128

.. GENERATED FROM PYTHON SOURCE LINES 556-557

Frames sampled from the replay buffer at each optimization step.

.. GENERATED FROM PYTHON SOURCE LINES 557-559

.. code-block:: Python

    batch_size = 32  # 256

.. GENERATED FROM PYTHON SOURCE LINES 560-561

Size of the replay buffer in terms of frames.

.. GENERATED FROM PYTHON SOURCE LINES 561-563

.. code-block:: Python

    buffer_size = min(total_frames, 100000)

.. GENERATED FROM PYTHON SOURCE LINES 564-565

Number of environments run in parallel in each data collector
(``num_workers``) and number of data collectors (``num_collectors``).

.. GENERATED FROM PYTHON SOURCE LINES 565-568

.. code-block:: Python

    num_workers = 2  # 8
    num_collectors = 2  # 4

.. GENERATED FROM PYTHON SOURCE LINES 569-576

Environment and exploration
~~~~~~~~~~~~~~~~~~~~~~~~~~~

We set the initial and final value of the epsilon factor in Epsilon-greedy
exploration.
Since our policy is deterministic, exploration is crucial: without it, the
only source of randomness would be the environment reset.

.. GENERATED FROM PYTHON SOURCE LINES 576-580

.. code-block:: Python

    eps_greedy_val = 0.1
    eps_greedy_val_env = 0.005

.. GENERATED FROM PYTHON SOURCE LINES 581-583

To speed up learning, we set the bias of the last layer of our value network
to a predefined value (this is not mandatory).

.. GENERATED FROM PYTHON SOURCE LINES 583-585

.. code-block:: Python

    init_bias = 2.0

.. GENERATED FROM PYTHON SOURCE LINES 586-591

.. note::
  For fast rendering of the tutorial, the ``total_frames`` hyperparameter
  was set to a very low number. To get a reasonable performance, use a greater
  value, e.g. 500000.

..
.. GENERATED FROM PYTHON SOURCE LINES 593-610

Building a Trainer
------------------

TorchRL's :class:`~torchrl.trainers.Trainer` class constructor takes the
following keyword-only arguments:

- ``collector``
- ``loss_module``
- ``optimizer``
- ``logger``: A logger can be
- ``total_frames``: this parameter defines the lifespan of the trainer.
- ``frame_skip``: when a frame-skip is used, the collector must be made
  aware of it in order to accurately count the number of frames
  collected etc. Making the trainer aware of this parameter is not
  mandatory but helps to have a fairer comparison between settings where
  the total number of frames (budget) is fixed but the frame-skip is
  variable.

.. GENERATED FROM PYTHON SOURCE LINES 610-633

.. code-block:: Python

    stats = get_norm_stats()
    test_env = make_env(parallel=False, obs_norm_sd=stats)
    # Get model
    actor, actor_explore = make_model(test_env)
    loss_module, target_net_updater = get_loss_module(actor, gamma)

    collector = get_collector(
        stats=stats,
        num_collectors=num_collectors,
        actor_explore=actor_explore,
        frames_per_batch=frames_per_batch,
        total_frames=total_frames,
        device=device,
    )
    optimizer = torch.optim.Adam(
        loss_module.parameters(), lr=lr, weight_decay=wd, betas=betas
    )
    exp_name = f"dqn_exp_{uuid.uuid1()}"
    tmpdir = tempfile.TemporaryDirectory()
    logger = CSVLogger(exp_name=exp_name, log_dir=tmpdir.name)
    warnings.warn(f"log dir: {logger.experiment.log_dir}")

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    state dict of the observation norm: OrderedDict([('standard_normal', tensor(True)), ('loc', tensor([[[0.9895]], [[0.9895]], [[0.9895]], [[0.9895]]])), ('scale', tensor([[[0.0737]], [[0.0737]], [[0.0737]], [[0.0737]]]))])

.. GENERATED FROM PYTHON SOURCE LINES 634-636

We can control how often the scalars should be logged. Here we set this
to a low value as our training loop is short:

.. GENERATED FROM PYTHON SOURCE LINES 636-650

..
.. code-block:: Python

    log_interval = 500

    trainer = Trainer(
        collector=collector,
        total_frames=total_frames,
        frame_skip=1,
        loss_module=loss_module,
        optimizer=optimizer,
        logger=logger,
        optim_steps_per_batch=n_optim,
        log_interval=log_interval,
    )

.. GENERATED FROM PYTHON SOURCE LINES 651-662

Registering hooks
~~~~~~~~~~~~~~~~~

Registering hooks can be achieved in two separate ways:

- If the hook has it, the :meth:`~torchrl.trainers.TrainerHookBase.register`
  method is the first choice. One just needs to provide the trainer as input
  and the hook will be registered with a default name at a default location.
  For some hooks, the registration can be quite complex:
  :class:`~torchrl.trainers.ReplayBufferTrainer` requires 3 hooks
  (``extend``, ``sample`` and ``update_priority``) which can be cumbersome
  to implement.

.. GENERATED FROM PYTHON SOURCE LINES 662-682

.. code-block:: Python

    buffer_hook = ReplayBufferTrainer(
        get_replay_buffer(buffer_size, n_optim, batch_size=batch_size),
        flatten_tensordicts=True,
    )
    buffer_hook.register(trainer)
    weight_updater = UpdateWeights(collector, update_weights_interval=1)
    weight_updater.register(trainer)
    recorder = Recorder(
        record_interval=100,  # log every 100 optimization steps
        record_frames=1000,  # maximum number of frames in the record
        frame_skip=1,
        policy_exploration=actor_explore,
        environment=test_env,
        exploration_type=ExplorationType.MODE,
        log_keys=[("next", "reward")],
        out_keys={("next", "reward"): "rewards"},
        log_pbar=True,
    )
    recorder.register(trainer)

.. GENERATED FROM PYTHON SOURCE LINES 683-685

The exploration module epsilon factor is also annealed:

.. GENERATED FROM PYTHON SOURCE LINES 685-688

.. code-block:: Python

    trainer.register_op("post_steps", actor_explore[1].step, frames=frames_per_batch)

.. GENERATED FROM PYTHON SOURCE LINES 689-697

- Any callable (including :class:`~torchrl.trainers.TrainerHookBase`
  subclasses) can be registered using
  :meth:`~torchrl.trainers.Trainer.register_op`.
  In this case, a location must be explicitly passed. This method gives
  more control over the location of the hook but it also requires more
  understanding of the Trainer mechanism.
  Check the :ref:`trainer documentation `
  for a detailed description of the trainer hooks.

.. GENERATED FROM PYTHON SOURCE LINES 697-699

.. code-block:: Python

    trainer.register_op("post_optim", target_net_updater.step)

.. GENERATED FROM PYTHON SOURCE LINES 700-707

We can log the training rewards too. Note that this is of limited interest
with CartPole, as rewards are always 1. The discounted sum of rewards is
maximised not by getting higher rewards but by keeping the cart-pole alive
for longer.
This will be reflected by the ``total_rewards`` value displayed in the
progress bar.

.. GENERATED FROM PYTHON SOURCE LINES 707-710

.. code-block:: Python

    log_reward = LogReward(log_pbar=True)
    log_reward.register(trainer)

.. GENERATED FROM PYTHON SOURCE LINES 711-720

.. note::
  It is possible to link multiple optimizers to the trainer if needed.
  In this case, each optimizer will be tied to a field in the loss
  dictionary.
  Check the :class:`~torchrl.trainers.OptimizerHook` to learn more.

Here we are, ready to train our algorithm! A simple call to
``trainer.train()`` and we'll be getting our results logged in.

.. GENERATED FROM PYTHON SOURCE LINES 720-722

.. code-block:: Python

    trainer.train()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

      0%|          | 0/5000 [00:00

- [...] of the documentation.
- A distributional loss (see :class:`~torchrl.objectives.DistributionalDQNLoss`
  for more information).
- Fancier exploration techniques, such as
  :class:`~torchrl.modules.NoisyLinear` layers.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (2 minutes 33.662 seconds)

**Estimated memory usage:** 853 MB

.. _sphx_glr_download_tutorials_coding_dqn.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    ..
    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: coding_dqn.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: coding_dqn.py `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_