.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/torchrl_demo.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_torchrl_demo.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_torchrl_demo.py:

Introduction to TorchRL
=======================

Get started with reinforcement learning in PyTorch.

.. GENERATED FROM PYTHON SOURCE LINES 9-22

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.
This tutorial provides a hands-on introduction to its main components.

**Key features:**

- **PyTorch-native**: Seamless integration with PyTorch's ecosystem
- **Modular**: Easily swap components and build custom pipelines
- **Efficient**: Optimized for both research and production
- **Comprehensive**: Environments, modules, losses, collectors, and more

By the end of this tutorial, you'll understand how TorchRL's components work
together to build RL training pipelines. Let's start with a quick example to
see what's possible.

.. GENERATED FROM PYTHON SOURCE LINES 37-43

Quick Start
-----------

Before diving into the details, here's a taste of what TorchRL can do.
In just a few lines, we can create an environment, build a policy, and
collect a trajectory:

.. GENERATED FROM PYTHON SOURCE LINES 43-63

.. code-block:: Python

    import torch

    from torchrl.envs import GymEnv
    from torchrl.modules import MLP, QValueActor

    env = GymEnv("CartPole-v1")
    actor = QValueActor(
        MLP(
            in_features=env.observation_spec["observation"].shape[-1],
            out_features=2,
            num_cells=[64, 64],
        ),
        in_keys=["observation"],
        spec=env.action_spec,
    )

    rollout = env.rollout(max_steps=200, policy=actor)
    print(
        f"Collected {rollout.shape[0]} steps, total reward: {rollout['next', 'reward'].sum().item():.0f}"
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Collected 24 steps, total reward: 24

.. GENERATED FROM PYTHON SOURCE LINES 64-82

That's it! We wrapped a Gym environment, created a Q-value actor with an MLP
backbone, and used :meth:`~torchrl.envs.EnvBase.rollout` to collect a full
trajectory. The result is a :class:`~tensordict.TensorDict` containing
observations, actions, rewards, and more. Now let's understand each component
in detail.

TensorDict: The Data Backbone
-----------------------------

At the heart of TorchRL is :class:`~tensordict.TensorDict` - a dictionary-like
container that holds tensors and supports batched operations. Think of it as a
"tensor of dictionaries" or a "dictionary of tensors" that knows about its
batch dimensions.

Why TensorDict? In RL, we constantly pass around groups of related tensors:
observations, actions, rewards, done flags, next observations, etc. TensorDict
keeps these organized and lets us manipulate them as a unit.

.. GENERATED FROM PYTHON SOURCE LINES 82-95

.. code-block:: Python

    from tensordict import TensorDict

    # Create a TensorDict representing a batch of 4 transitions
    batch_size = 4
    data = TensorDict(
        obs=torch.randn(batch_size, 3),
        action=torch.randn(batch_size, 2),
        reward=torch.randn(batch_size, 1),
        batch_size=[batch_size],
    )
    print(data)

.. rst-class:: sphx-glr-script-out

..
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 96-99 TensorDicts support all the operations you'd expect from PyTorch tensors. You can index them, slice them, move them between devices, and stack them together - all while keeping the dictionary structure intact: .. GENERATED FROM PYTHON SOURCE LINES 99-112 .. code-block:: Python # Indexing works just like tensors - grab the first transition print("First element:", data[0]) print("Slice:", data[:2]) # Device transfer moves all contained tensors data_cpu = data.to("cpu") # Stacking is especially useful for building trajectories data2 = data.clone() stacked = torch.stack([data, data2], dim=0) print("Stacked shape:", stacked.batch_size) .. rst-class:: sphx-glr-script-out .. code-block:: none First element: TensorDict( fields={ action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) Slice: TensorDict( fields={ action: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), obs: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([2, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) Stacked shape: torch.Size([2, 4]) .. GENERATED FROM PYTHON SOURCE LINES 113-116 TensorDicts can also be nested, which is useful for organizing complex observations (e.g., an agent that receives both image pixels and vector state) or for separating "current" from "next" step data: .. GENERATED FROM PYTHON SOURCE LINES 116-128 .. code-block:: Python nested = TensorDict( observation=TensorDict( pixels=torch.randn(4, 3, 84, 84), vector=torch.randn(4, 10), batch_size=[4], ), action=torch.randn(4, 2), batch_size=[4], ) print(nested) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), observation: TensorDict( fields={ pixels: Tensor(shape=torch.Size([4, 3, 84, 84]), device=cpu, dtype=torch.float32, is_shared=False), vector: Tensor(shape=torch.Size([4, 10]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 129-140 Environments ------------ TorchRL provides a unified interface for RL environments. Whether you're using Gym, DMControl, IsaacGym, or other simulators, the API stays the same: environments accept and return TensorDicts. **Creating Environments** The simplest way to create an environment is with :class:`~torchrl.envs.GymEnv`, which wraps any Gymnasium (or legacy Gym) environment: .. GENERATED FROM PYTHON SOURCE LINES 140-147 .. code-block:: Python from torchrl.envs import GymEnv env = GymEnv("Pendulum-v1") print("Action spec:", env.action_spec) print("Observation spec:", env.observation_spec) .. 
rst-class:: sphx-glr-script-out .. code-block:: none Action spec: BoundedContinuous( shape=torch.Size([1]), space=ContinuousBox( low=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous) Observation spec: Composite( observation: BoundedContinuous( shape=torch.Size([3]), space=ContinuousBox( low=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous), device=None, shape=torch.Size([]), data_cls=None) .. GENERATED FROM PYTHON SOURCE LINES 148-153 Every environment has *specs* that describe the shape and bounds of observations, actions, rewards, and done flags. These specs are essential for building correctly-shaped networks and for validating data. The environment interaction follows a familiar pattern - reset, then step: .. GENERATED FROM PYTHON SOURCE LINES 153-162 .. code-block:: Python td = env.reset() print("Reset output:", td) # Sample a random action and take a step td["action"] = env.action_spec.rand() td = env.step(td) print("Step output:", td) .. rst-class:: sphx-glr-script-out .. code-block:: none Reset output: TensorDict( fields={ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) Step output: TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 163-173 Notice that :meth:`~torchrl.envs.EnvBase.step` returns the same TensorDict with additional keys filled in: the ``"next"`` sub-TensorDict contains the resulting observation, reward, and done flag. **Transforms** Just like torchvision transforms for images, TorchRL provides transforms for environments. These modify observations, actions, or rewards in a composable way. Common uses include normalizing observations, stacking frames, or adding step counters: .. GENERATED FROM PYTHON SOURCE LINES 173-184 .. 
code-block:: Python from torchrl.envs import Compose, StepCounter, TransformedEnv env = TransformedEnv( GymEnv("Pendulum-v1"), Compose( StepCounter(max_steps=200), # Track steps and auto-terminate ), ) print("Transformed env:", env) .. rst-class:: sphx-glr-script-out .. code-block:: none Transformed env: TransformedEnv( env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None), transform=Compose( StepCounter(keys=[]))) .. GENERATED FROM PYTHON SOURCE LINES 185-197 **Batched Environments** RL algorithms are data-hungry. Running multiple environment instances in parallel can dramatically speed up data collection. TorchRL's :class:`~torchrl.envs.ParallelEnv` runs environments in separate processes, returning batched TensorDicts: .. note:: ``ParallelEnv`` uses multiprocessing. The ``mp_start_method`` parameter controls how processes are spawned: ``"fork"`` (Linux default) is fast but can have issues with some libraries; ``"spawn"`` (Windows/macOS default) is safer but requires code to be guarded with ``if __name__ == "__main__"``. .. GENERATED FROM PYTHON SOURCE LINES 197-216 .. code-block:: Python from torchrl.envs import ParallelEnv def make_env(): return GymEnv("Pendulum-v1") # Run 4 environments in parallel vec_env = ParallelEnv(4, make_env) td = vec_env.reset() print("Batched reset:", td.batch_size) td["action"] = vec_env.action_spec.rand() td = vec_env.step(td) print("Batched step:", td.batch_size) vec_env.close() .. rst-class:: sphx-glr-script-out .. code-block:: none Batched reset: torch.Size([4]) Batched step: torch.Size([4]) .. GENERATED FROM PYTHON SOURCE LINES 217-232 The batch dimension (4 in this case) propagates through all tensors, making it easy to process multiple environments with a single forward pass. Modules and Policies -------------------- TorchRL extends PyTorch's ``nn.Module`` system with modules that read from and write to TensorDicts. This makes it easy to build policies that integrate seamlessly with the environment interface. **TensorDictModule** The core building block is :class:`~tensordict.nn.TensorDictModule`. It wraps any ``nn.Module`` and specifies which TensorDict keys to read as inputs and which keys to write as outputs: .. GENERATED FROM PYTHON SOURCE LINES 232-244 .. code-block:: Python from tensordict.nn import TensorDictModule from torch import nn module = nn.Linear(3, 2) td_module = TensorDictModule(module, in_keys=["observation"], out_keys=["action"]) # The module reads "observation" and writes "action" td = TensorDict(observation=torch.randn(4, 3), batch_size=[4]) td_module(td) print(td) # Now has "action" key .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 245-253 This pattern has a powerful benefit: modules become composable. You can chain them together, and each module only needs to know about its own input/output keys. **Built-in Networks** TorchRL includes common network architectures used in RL. These are regular PyTorch modules that you can wrap with TensorDictModule: .. GENERATED FROM PYTHON SOURCE LINES 253-264 .. 
code-block:: Python from torchrl.modules import ConvNet, MLP # MLP for vector observations - specify input/output dims and hidden layers mlp = MLP(in_features=64, out_features=10, num_cells=[128, 128]) print(mlp(torch.randn(4, 64)).shape) # ConvNet for image observations - outputs a flat feature vector cnn = ConvNet(num_cells=[32, 64], kernel_sizes=[8, 4], strides=[4, 2]) print(cnn(torch.randn(4, 3, 84, 84)).shape) .. rst-class:: sphx-glr-script-out .. code-block:: none torch.Size([4, 10]) torch.Size([4, 5184]) .. GENERATED FROM PYTHON SOURCE LINES 265-271 **Probabilistic Policies** Many RL algorithms (PPO, SAC, etc.) use stochastic policies that output probability distributions over actions. TorchRL provides :class:`~tensordict.nn.ProbabilisticTensorDictModule` to sample from distributions and optionally compute log-probabilities: .. GENERATED FROM PYTHON SOURCE LINES 271-300 .. code-block:: Python from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, ) from torchrl.modules import NormalParamExtractor, TanhNormal # The network outputs mean and std (via NormalParamExtractor) net = nn.Sequential( nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 4), NormalParamExtractor() ) backbone = TensorDictModule(net, in_keys=["observation"], out_keys=["loc", "scale"]) # Combine backbone with a distribution sampler policy = ProbabilisticTensorDictSequential( backbone, ProbabilisticTensorDictModule( in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, ), ) td = TensorDict(observation=torch.randn(4, 3), batch_size=[4]) policy(td) print("Sampled action:", td["action"].shape) print("Log prob:", td["action_log_prob"].shape) .. rst-class:: sphx-glr-script-out .. code-block:: none Sampled action: torch.Size([4, 2]) Log prob: torch.Size([4]) .. GENERATED FROM PYTHON SOURCE LINES 301-315 The ``TanhNormal`` distribution squashes samples to [-1, 1], which is useful for continuous control. The log-probability accounts for this transformation, which is crucial for policy gradient methods. Data Collection --------------- In RL, we need to repeatedly collect experience from the environment. While you can write your own rollout loop, TorchRL's *collectors* handle this efficiently, including batching, device management, and multi-process collection. The :class:`~torchrl.collectors.SyncDataCollector` collects data synchronously - it waits for a batch to be ready before returning: .. GENERATED FROM PYTHON SOURCE LINES 315-335 .. code-block:: Python from torchrl.collectors import SyncDataCollector # A simple deterministic policy for demonstration actor = TensorDictModule(nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]) collector = SyncDataCollector( create_env_fn=lambda: GymEnv("Pendulum-v1"), policy=actor, frames_per_batch=200, # Collect 200 frames per iteration total_frames=1000, # Stop after 1000 total frames ) for batch in collector: print( f"Collected batch: {batch.shape}, reward: {batch['next', 'reward'].mean():.2f}" ) collector.shutdown() .. rst-class:: sphx-glr-script-out .. code-block:: none Collected batch: torch.Size([200]), reward: -8.15 Collected batch: torch.Size([200]), reward: -8.81 Collected batch: torch.Size([200]), reward: -9.42 Collected batch: torch.Size([200]), reward: -9.33 Collected batch: torch.Size([200]), reward: -9.26 .. GENERATED FROM PYTHON SOURCE LINES 336-345 For async collection (useful when training takes longer than collecting), see :class:`~torchrl.collectors.MultiaSyncDataCollector`. 
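Its constructor mirrors :class:`~torchrl.collectors.SyncDataCollector`, except that it
takes one environment constructor per worker and yields batches from whichever worker
finishes first. Here is a rough sketch (double-check the exact arguments against your
installed TorchRL version):

.. code-block:: Python

    from torchrl.collectors import MultiaSyncDataCollector

    # Sketch only: the workers run in separate processes, so on Windows/macOS
    # this belongs under an if __name__ == "__main__": guard.
    async_collector = MultiaSyncDataCollector(
        create_env_fn=[lambda: GymEnv("Pendulum-v1")] * 2,  # one env fn per worker
        policy=actor,
        frames_per_batch=200,
        total_frames=1000,
    )

    for batch in async_collector:
        # Batches arrive as soon as any worker has finished collecting them,
        # so consecutive batches may come from different workers.
        print("Async batch:", batch.shape)

    async_collector.shutdown()

Asynchronous collection trades strict ordering for throughput: the policy weights used
to collect a given batch may lag behind the latest update, which is usually acceptable
for off-policy algorithms.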
Replay Buffers -------------- Most RL algorithms don't learn from experience immediately - they store transitions in a buffer and sample mini-batches for training. TorchRL's replay buffers handle this efficiently: .. GENERATED FROM PYTHON SOURCE LINES 345-359 .. code-block:: Python from torchrl.data import LazyTensorStorage, ReplayBuffer buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=10000)) # Add a batch of experience buffer.extend( TensorDict(obs=torch.randn(100, 4), action=torch.randn(100, 2), batch_size=[100]) ) # Sample a mini-batch for training sample = buffer.sample(32) print("Sampled batch:", sample.batch_size) .. rst-class:: sphx-glr-script-out .. code-block:: none Sampled batch: torch.Size([32]) .. GENERATED FROM PYTHON SOURCE LINES 360-363 The :class:`~torchrl.data.LazyTensorStorage` allocates memory lazily based on the first batch added. For prioritized experience replay (used in DQN variants), use :class:`~torchrl.data.PrioritizedReplayBuffer`: .. GENERATED FROM PYTHON SOURCE LINES 363-377 .. code-block:: Python from torchrl.data import PrioritizedReplayBuffer buffer = PrioritizedReplayBuffer( alpha=0.6, # Priority exponent beta=0.4, # Importance sampling exponent storage=LazyTensorStorage(max_size=10000), ) buffer.extend(TensorDict(obs=torch.randn(100, 4), batch_size=[100])) # Use return_info=True to get sampling metadata (indices, weights) sample, info = buffer.sample(32, return_info=True) print("Prioritized sample indices:", info["index"][:5], "...") # First 5 indices .. rst-class:: sphx-glr-script-out .. code-block:: none Prioritized sample indices: tensor([42, 33, 61, 39, 54]) ... .. GENERATED FROM PYTHON SOURCE LINES 378-392 Loss Functions -------------- The final piece is the objective function. TorchRL provides loss classes for major RL algorithms, encapsulating the often-complex loss computations: - :class:`~torchrl.objectives.DQNLoss` - Deep Q-Networks - :class:`~torchrl.objectives.DDPGLoss` - Deep Deterministic Policy Gradient - :class:`~torchrl.objectives.SACLoss` - Soft Actor-Critic - :class:`~torchrl.objectives.PPOLoss` - Proximal Policy Optimization - :class:`~torchrl.objectives.TD3Loss` - Twin Delayed DDPG Here's how to set up a DQN loss. We create a Q-network wrapped in a :class:`~torchrl.modules.QValueActor`, which handles action selection: .. GENERATED FROM PYTHON SOURCE LINES 392-407 .. code-block:: Python from torchrl.objectives import DQNLoss qnet = TensorDictModule( nn.Sequential(nn.Linear(4, 64), nn.ReLU(), nn.Linear(64, 2)), in_keys=["observation"], out_keys=["action_value"], ) # QValueActor wraps the Q-network to select actions and output chosen values from torchrl.data import Categorical actor = QValueActor(qnet, in_keys=["observation"], spec=Categorical(n=2)) loss_fn = DQNLoss(actor, action_space="categorical") .. GENERATED FROM PYTHON SOURCE LINES 408-410 The loss function expects batches with specific keys. Let's create a dummy batch to see it in action: .. GENERATED FROM PYTHON SOURCE LINES 410-427 .. code-block:: Python batch = TensorDict( observation=torch.randn(32, 4), action=torch.randint(0, 2, (32,)), next=TensorDict( observation=torch.randn(32, 4), reward=torch.randn(32, 1), done=torch.zeros(32, 1, dtype=torch.bool), terminated=torch.zeros(32, 1, dtype=torch.bool), batch_size=[32], ), batch_size=[32], ) loss_td = loss_fn(batch) print("Loss:", loss_td["loss"]) .. rst-class:: sphx-glr-script-out .. code-block:: none Loss: tensor(1.0701, grad_fn=) .. 
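Two ingredients the dummy batch above does not exercise are target-network updates and
exploration. Here is a minimal sketch of how they can be attached to the loss we just
built, using :class:`~torchrl.objectives.SoftUpdate` and
:class:`~torchrl.modules.EGreedyModule` (the exact keyword arguments - ``delay_value``,
``tau``, ``annealing_num_steps`` - are worth checking against your installed TorchRL
version):

.. code-block:: Python

    from tensordict.nn import TensorDictSequential
    from torchrl.modules import EGreedyModule
    from torchrl.objectives import SoftUpdate

    # Rebuild the loss with delayed (target) value parameters, then attach a
    # soft target-network updater that blends them toward the online weights.
    loss_fn = DQNLoss(actor, action_space="categorical", delay_value=True)
    target_updater = SoftUpdate(loss_fn, tau=0.005)

    # Epsilon-greedy exploration: wrap the greedy QValueActor so that random
    # actions are taken with a probability annealed from eps_init to eps_end.
    exploration_module = EGreedyModule(
        spec=Categorical(n=2), eps_init=1.0, eps_end=0.05, annealing_num_steps=10_000
    )
    exploration_policy = TensorDictSequential(actor, exploration_module)
    # exploration_policy would be passed to the collector in place of the greedy actor.

    # After each optimizer step one would typically call:
    #     target_updater.step()      # nudge the target params toward the online ones
    #     exploration_module.step()  # anneal epsilon

:class:`~torchrl.objectives.HardUpdate` is the alternative updater: it copies the online
weights into the target network at fixed intervals instead of blending them continuously.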
.. GENERATED FROM PYTHON SOURCE LINES 428-436

The loss module manages its own target-network parameters, computes the Bellman
backup, and takes care of the bookkeeping needed for stable training.

Putting It All Together
-----------------------

Now let's see how all these components work together in a complete training
loop. We'll train a simple DQN agent on CartPole:

.. GENERATED FROM PYTHON SOURCE LINES 436-489

.. code-block:: Python

    torch.manual_seed(0)

    # 1. Create the environment
    env = GymEnv("CartPole-v1")

    # 2. Build a Q-network and wrap it as a policy
    qnet = TensorDictModule(
        nn.Sequential(nn.Linear(4, 128), nn.ReLU(), nn.Linear(128, 2)),
        in_keys=["observation"],
        out_keys=["action_value"],
    )
    policy = QValueActor(qnet, in_keys=["observation"], spec=env.action_spec)

    # 3. Set up the data collector
    collector = SyncDataCollector(
        create_env_fn=lambda: GymEnv("CartPole-v1"),
        policy=policy,
        frames_per_batch=100,
        total_frames=2000,
    )

    # 4. Create a replay buffer
    buffer = ReplayBuffer(storage=LazyTensorStorage(max_size=10000))

    # 5. Set up the loss and optimizer (pass the QValueActor, not just the network)
    loss_fn = DQNLoss(policy, action_space=env.action_spec)
    optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

    # 6. Training loop: collect -> store -> sample -> train
    for i, batch in enumerate(collector):
        # Store collected experience
        buffer.extend(batch)

        # Wait until we have enough data
        if len(buffer) < 100:
            continue

        # Sample a batch and compute the loss
        sample = buffer.sample(64)
        loss = loss_fn(sample)

        # Standard PyTorch optimization step
        optimizer.zero_grad()
        loss["loss"].backward()
        optimizer.step()

        if i % 5 == 0:
            print(f"Step {i}: loss={loss['loss'].item():.3f}")

    collector.shutdown()
    env.close()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Step 0: loss=1.027
    Step 5: loss=0.687
    Step 10: loss=0.451
    Step 15: loss=0.269

.. GENERATED FROM PYTHON SOURCE LINES 490-528

This is a minimal example - a production DQN would add target network updates
(e.g. :class:`~torchrl.objectives.SoftUpdate`) and epsilon-greedy exploration
(e.g. :class:`~torchrl.modules.EGreedyModule`), and more. Check out the full
implementations in ``sota-implementations/dqn/``.

What's Next?
------------

This tutorial covered the basics. TorchRL has much more to offer:

**Tutorials:**

- `PPO Tutorial <../tutorials/coding_ppo.html>`_ - Train PPO on MuJoCo
- `DQN Tutorial <../tutorials/coding_dqn.html>`_ - Deep Q-Learning from scratch
- `Multi-Agent RL <../tutorials/multiagent_ppo.html>`_ - Cooperative and competitive agents

**SOTA Implementations:**

The `sota-implementations/ <https://github.com/pytorch/rl/tree/main/sota-implementations>`_
folder contains production-ready implementations of:

- PPO, A2C, SAC, TD3, DDPG, DQN
- Offline RL: CQL, IQL, Decision Transformer
- Multi-agent: IPPO, QMIX, MADDPG
- LLM training: GRPO, Expert Iteration

**Advanced Features:**

- Distributed training with Ray and RPC
- Offline RL datasets (D4RL, Minari)
- Model-based RL (Dreamer)
- LLM integration for RLHF

**Resources:**

- `API Reference <https://pytorch.org/rl/>`_
- `GitHub <https://github.com/pytorch/rl>`_
- `Contributing Guide <https://github.com/pytorch/rl/blob/main/CONTRIBUTING.md>`_

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 4.362 seconds)

.. _sphx_glr_download_tutorials_torchrl_demo.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torchrl_demo.ipynb <torchrl_demo.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torchrl_demo.py <torchrl_demo.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torchrl_demo.zip <torchrl_demo.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_