.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/torchrl_demo.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end `
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_torchrl_demo.py:


Introduction to TorchRL
=======================
This demo was presented at ICML 2022 on the industry demo day.

.. GENERATED FROM PYTHON SOURCE LINES 7-186

It gives a good overview of TorchRL's functionality. Feel free to reach out
to vmoens@fb.com or submit issues if you have questions or comments about
it.

TorchRL is an open-source Reinforcement Learning (RL) library for PyTorch.

https://github.com/pytorch/rl

The PyTorch ecosystem team (Meta) has decided to invest in this library to
provide a leading platform for developing RL solutions in research settings.

It provides PyTorch- and **Python-first**, low- and high-level
**abstractions** for RL that are intended to be efficient, documented and
properly tested.
The code is aimed at supporting research in RL. Most of it is written in
Python in a highly modular way, so that researchers can easily swap
components, transform them or write new ones with little effort.

This repo attempts to align with the existing PyTorch ecosystem libraries
in that it has a dataset pillar (torchrl/envs), transforms, models, data
utilities (e.g. collectors and containers), etc.
TorchRL aims to have as few dependencies as possible (the Python standard
library, NumPy and PyTorch). Common environment libraries (e.g. OpenAI gym)
are optional.

**Content**:

..
aafig:: "torchrl" │ ├── "collectors" │ └── "collectors.py" │ │ │ └── "distributed" │ └── "default_configs.py" │ └── "generic.py" │ └── "ray.py" │ └── "rpc.py" │ └── "sync.py" ├── "data" │ │ │ ├── "datasets" │ │ └── "atari_dqn.py" │ │ └── "d4rl.py" │ │ └── "d4rl_infos.py" │ │ └── "gen_dgrl.py" │ │ └── "minari_data.py" │ │ └── "openml.py" │ │ └── "openx.py" │ │ └── "roboset.py" │ │ └── "vd4rl.py" │ ├── "postprocs" │ │ └── "postprocs.py" │ ├── "replay_buffers" │ │ └── "replay_buffers.py" │ │ └── "samplers.py" │ │ └── "storages.py" │ │ └── "writers.py" │ ├── "rlhf" │ │ └── "dataset.py" │ │ └── "prompt.py" │ │ └── "reward.py" │ └── "tensor_specs.py" ├── "envs" │ └── "batched_envs.py" │ └── "common.py" │ └── "env_creator.py" │ └── "gym_like.py" │ ├── "libs" │ │ └── "brax.py" │ │ └── "dm_control.py" │ │ └── "envpool.py" │ │ └── "gym.py" │ │ └── "habitat.py" │ │ └── "isaacgym.py" │ │ └── "jumanji.py" │ │ └── "openml.py" │ │ └── "pettingzoo.py" │ │ └── "robohive.py" │ │ └── "smacv2.py" │ │ └── "vmas.py" │ ├── "model_based" │ │ └── "common.py" │ │ └── "dreamer.py" │ ├── "transforms" │ │ └── "functional.py" │ │ └── "gym_transforms.py" │ │ └── "r3m.py" │ │ └── "rlhf.py" │ │ └── "vc1.py" │ │ └── "vip.py" │ └── "vec_envs.py" ├── "modules" │ ├── "distributions" │ │ └── "continuous.py" │ │ └── "discrete.py" │ │ └── "truncated_normal.py" │ ├── "models" │ │ └── "decision_transformer.py" │ │ └── "exploration.py" │ │ └── "model_based.py" │ │ └── "models.py" │ │ └── "multiagent.py" │ │ └── "rlhf.py" │ ├── "planners" │ │ └── "cem.py" │ │ └── "common.py" │ │ └── "mppi.py" │ └── "tensordict_module" │ └── "actors.py" │ └── "common.py" │ └── "exploration.py" │ └── "probabilistic.py" │ └── "rnn.py" │ └── "sequence.py" │ └── "world_models.py" ├── "objectives" │ └── "a2c.py" │ └── "common.py" │ └── "cql.py" │ └── "ddpg.py" │ └── "decision_transformer.py" │ └── "deprecated.py" │ └── "dqn.py" │ └── "dreamer.py" │ └── "functional.py" │ └── "iql.py" │ ├── "multiagent" │ │ └── "qmixer.py" │ └── 
"ppo.py" │ └── "redq.py" │ └── "reinforce.py" │ └── "sac.py" │ └── "td3.py" │ ├── "value" │ └── "advantages.py" │ └── "functional.py" │ └── "pg.py" ├── "record" │ ├── "loggers" │ │ └── "common.py" │ │ └── "csv.py" │ │ └── "mlflow.py" │ │ └── "tensorboard.py" │ │ └── "wandb.py" │ └── "recorder.py" ├── "trainers" │ │ │ ├── "helpers" │ │ └── "collectors.py" │ │ └── "envs.py" │ │ └── "logger.py" │ │ └── "losses.py" │ │ └── "models.py" │ │ └── "replay_buffer.py" │ │ └── "trainers.py" │ └── "trainers.py" └── "version.py" Unlike other domains, RL is less about media than *algorithms*. As such, it is harder to make truly independent components. What TorchRL is not: * a collection of algorithms: we do not intend to provide SOTA implementations of RL algorithms, but we provide these algorithms only as examples of how to use the library. * a research framework: modularity in TorchRL comes in two flavours. First, we try to build re-usable components, such that they can be easily swapped with each other. Second, we make our best such that components can be used independently of the rest of the library. TorchRL has very few core dependencies, predominantly PyTorch and numpy. All other dependencies (gym, torchvision, wandb / tensorboard) are optional. Data ^^^^ TensorDict ---------- .. GENERATED FROM PYTHON SOURCE LINES 186-191 .. code-block:: Python import torch from tensordict import TensorDict .. GENERATED FROM PYTHON SOURCE LINES 214-215 Let's create a TensorDict. .. GENERATED FROM PYTHON SOURCE LINES 215-226 .. code-block:: Python batch_size = 5 tensordict = TensorDict( source={ "key 1": torch.zeros(batch_size, 3), "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool), }, batch_size=[batch_size], ) print(tensordict) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False), key 2: Tensor(shape=torch.Size([5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([5]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 227-228 You can index a TensorDict as well as query keys. .. GENERATED FROM PYTHON SOURCE LINES 228-232 .. code-block:: Python print(tensordict[2]) print(tensordict["key 1"] is tensordict.get("key 1")) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ key 1: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), key 2: Tensor(shape=torch.Size([5, 6]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) True .. GENERATED FROM PYTHON SOURCE LINES 233-234 The following shows how to stack multiple TensorDicts. .. GENERATED FROM PYTHON SOURCE LINES 234-254 .. code-block:: Python tensordict1 = TensorDict( source={ "key 1": torch.zeros(batch_size, 1), "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool), }, batch_size=[batch_size], ) tensordict2 = TensorDict( source={ "key 1": torch.ones(batch_size, 1), "key 2": torch.ones(batch_size, 5, 6, dtype=torch.bool), }, batch_size=[batch_size], ) tensordict = torch.stack([tensordict1, tensordict2], 0) tensordict.batch_size, tensordict["key 1"] .. rst-class:: sphx-glr-script-out .. code-block:: none (torch.Size([2, 5]), tensor([[[0.], [0.], [0.], [0.], [0.]], [[1.], [1.], [1.], [1.], [1.]]])) .. GENERATED FROM PYTHON SOURCE LINES 255-256 Here are some other functionalities of TensorDict. .. GENERATED FROM PYTHON SOURCE LINES 256-281 .. 
code-block:: Python print( "view(-1): ", tensordict.view(-1).batch_size, tensordict.view(-1).get("key 1").shape, ) print("to device: ", tensordict.to("cpu")) # print("pin_memory: ", tensordict.pin_memory()) print("share memory: ", tensordict.share_memory_()) print( "permute(1, 0): ", tensordict.permute(1, 0).batch_size, tensordict.permute(1, 0).get("key 1").shape, ) print( "expand: ", tensordict.expand(3, *tensordict.batch_size).batch_size, tensordict.expand(3, *tensordict.batch_size).get("key 1").shape, ) .. rst-class:: sphx-glr-script-out .. code-block:: none view(-1): torch.Size([10]) torch.Size([10, 1]) to device: TensorDict( fields={ key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False), key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2, 5]), device=cpu, is_shared=False) share memory: TensorDict( fields={ key 1: Tensor(shape=torch.Size([2, 5, 1]), device=cpu, dtype=torch.float32, is_shared=True), key 2: Tensor(shape=torch.Size([2, 5, 5, 6]), device=cpu, dtype=torch.bool, is_shared=True)}, batch_size=torch.Size([2, 5]), device=None, is_shared=True) permute(1, 0): torch.Size([5, 2]) torch.Size([5, 2, 1]) expand: torch.Size([3, 2, 5]) torch.Size([3, 2, 5, 1]) .. GENERATED FROM PYTHON SOURCE LINES 282-283 You can create a **nested TensorDict** as well. .. GENERATED FROM PYTHON SOURCE LINES 283-296 .. code-block:: Python tensordict = TensorDict( source={ "key 1": torch.zeros(batch_size, 3), "key 2": TensorDict( source={"sub-key 1": torch.zeros(batch_size, 2, 1)}, batch_size=[batch_size, 2], ), }, batch_size=[batch_size], ) tensordict .. rst-class:: sphx-glr-script-out .. 
code-block:: none

    TensorDict(
        fields={
            key 1: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            key 2: TensorDict(
                fields={
                    sub-key 1: Tensor(shape=torch.Size([5, 2, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([5, 2]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)


.. GENERATED FROM PYTHON SOURCE LINES 297-299

Replay buffers
------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 299-302

.. code-block:: Python


    from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer


.. GENERATED FROM PYTHON SOURCE LINES 303-308

.. code-block:: Python


    rb = ReplayBuffer(collate_fn=lambda x: x)
    rb.add(1)
    rb.sample(1)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [1]


.. GENERATED FROM PYTHON SOURCE LINES 309-313

.. code-block:: Python


    rb.extend([2, 3])
    rb.sample(3)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    [2, 1, 3]


.. GENERATED FROM PYTHON SOURCE LINES 314-320

.. code-block:: Python


    rb = PrioritizedReplayBuffer(alpha=0.7, beta=1.1, collate_fn=lambda x: x)
    rb.add(1)
    rb.sample(1)
    rb.update_priority(1, 0.5)


.. GENERATED FROM PYTHON SOURCE LINES 321-322

Here are examples of using a replay buffer with TensorDicts.

.. GENERATED FROM PYTHON SOURCE LINES 322-328

.. code-block:: Python


    collate_fn = torch.stack
    rb = ReplayBuffer(collate_fn=collate_fn)
    rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
    len(rb)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    1


.. GENERATED FROM PYTHON SOURCE LINES 329-335

.. code-block:: Python


    rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
    print(len(rb))
    print(rb.sample(10))
    print(rb.sample(2).contiguous())


.. rst-class:: sphx-glr-script-out

..
code-block:: none 3 TensorDict( fields={ a: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False) TensorDict( fields={ a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 336-345 .. code-block:: Python torch.manual_seed(0) from torchrl.data import TensorDictPrioritizedReplayBuffer rb = TensorDictPrioritizedReplayBuffer(alpha=0.7, beta=1.1, priority_key="td_error") rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2])) tensordict_sample = rb.sample(2).contiguous() tensordict_sample .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ _weight: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), a: Tensor(shape=torch.Size([2, 3]), device=cpu, dtype=torch.float32, is_shared=False), index: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 346-349 .. code-block:: Python tensordict_sample["index"] .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([0, 0]) .. GENERATED FROM PYTHON SOURCE LINES 350-364 .. code-block:: Python tensordict_sample["td_error"] = torch.rand(2) rb.update_tensordict_priority(tensordict_sample) for i, val in enumerate(rb._sampler._sum_tree): print(i, val) if i == len(rb): break try: import gymnasium as gym except ModuleNotFoundError: import gym .. rst-class:: sphx-glr-script-out .. code-block:: none 0 0.28791671991348267 1 1.0 2 0.0 .. GENERATED FROM PYTHON SOURCE LINES 365-367 Envs ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. GENERATED FROM PYTHON SOURCE LINES 367-374 .. code-block:: Python from torchrl.envs.libs.gym import GymEnv, GymWrapper gym_env = gym.make("Pendulum-v1") env = GymWrapper(gym_env) env = GymEnv("Pendulum-v1") .. 
GENERATED FROM PYTHON SOURCE LINES 375-379 .. code-block:: Python tensordict = env.reset() env.rand_step(tensordict) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 380-382 Changing environments config ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 382-386 .. code-block:: Python env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) env.reset() .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), pixels: Tensor(shape=torch.Size([500, 500, 3]), device=cpu, dtype=torch.uint8, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 387-391 .. code-block:: Python env.close() del env .. GENERATED FROM PYTHON SOURCE LINES 392-405 .. code-block:: Python from torchrl.envs import ( Compose, NoopResetEnv, ObservationNorm, ToTensorImage, TransformedEnv, ) base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage())) env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) .. rst-class:: sphx-glr-script-out .. code-block:: none TransformedEnv( env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None), transform=Compose( NoopResetEnv(noops=3, random=True), ToTensorImage(keys=['pixels']), ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels']))) .. GENERATED FROM PYTHON SOURCE LINES 406-408 Transforms ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 408-422 .. code-block:: Python from torchrl.envs import ( Compose, NoopResetEnv, ObservationNorm, StepCounter, ToTensorImage, TransformedEnv, ) base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False) env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage())) env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TransformedEnv( env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None), transform=Compose( NoopResetEnv(noops=3, random=True), ToTensorImage(keys=['pixels']), ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels']))) .. GENERATED FROM PYTHON SOURCE LINES 423-426 .. code-block:: Python env.reset() .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), pixels: Tensor(shape=torch.Size([3, 500, 500]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 427-431 .. code-block:: Python print("env: ", env) print("last transform parent: ", env.transform[2].parent) .. rst-class:: sphx-glr-script-out .. code-block:: none env: TransformedEnv( env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None), transform=Compose( NoopResetEnv(noops=3, random=True), ToTensorImage(keys=['pixels']), ObservationNorm(loc=2.0000, scale=1.0000, keys=['pixels']))) last transform parent: TransformedEnv( env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=None), transform=Compose( NoopResetEnv(noops=3, random=True), ToTensorImage(keys=['pixels']))) .. GENERATED FROM PYTHON SOURCE LINES 432-434 Vectorized Environments ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 434-448 .. code-block:: Python from torchrl.envs import ParallelEnv base_env = ParallelEnv( 4, lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False), mp_start_method="fork", # This will break on Windows machines! 
Remove and decorate with if __name__ == "__main__" ) env = TransformedEnv( base_env, Compose(StepCounter(), ToTensorImage()) ) # applies transforms on batch of envs env.append_transform(ObservationNorm(in_keys=["pixels"], loc=2, scale=1)) env.reset() .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ done: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False), pixels: Tensor(shape=torch.Size([4, 3, 500, 500]), device=cpu, dtype=torch.float32, is_shared=False), step_count: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.int64, is_shared=False), terminated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([4]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 449-455 .. code-block:: Python print(env.action_spec) env.close() del env .. rst-class:: sphx-glr-script-out .. code-block:: none BoundedTensorSpec( shape=torch.Size([4, 1]), space=ContinuousBox( low=Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, contiguous=True), high=Tensor(shape=torch.Size([4, 1]), device=cpu, dtype=torch.float32, contiguous=True)), device=cpu, dtype=torch.float32, domain=continuous) .. GENERATED FROM PYTHON SOURCE LINES 456-463 Modules ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Models ------------------------------ Example of a MLP model: .. GENERATED FROM PYTHON SOURCE LINES 463-466 .. code-block:: Python from torch import nn .. GENERATED FROM PYTHON SOURCE LINES 467-475 .. code-block:: Python from torchrl.modules import ConvNet, MLP from torchrl.modules.models.utils import SquashDims net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU) print(net) print(net(torch.randn(10, 3)).shape) .. rst-class:: sphx-glr-script-out .. 
code-block:: none MLP( (0): LazyLinear(in_features=0, out_features=32, bias=True) (1): ELU(alpha=1.0) (2): Linear(in_features=32, out_features=64, bias=True) (3): ELU(alpha=1.0) (4): Linear(in_features=64, out_features=4, bias=True) ) torch.Size([10, 4]) .. GENERATED FROM PYTHON SOURCE LINES 476-477 Example of a CNN model: .. GENERATED FROM PYTHON SOURCE LINES 477-487 .. code-block:: Python cnn = ConvNet( num_cells=[32, 64], kernel_sizes=[8, 4], strides=[2, 1], aggregator_class=SquashDims, ) print(cnn) print(cnn(torch.randn(10, 3, 32, 32)).shape) # last tensor is squashed .. rst-class:: sphx-glr-script-out .. code-block:: none ConvNet( (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(2, 2)) (1): ELU(alpha=1.0) (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1)) (3): ELU(alpha=1.0) (4): SquashDims() ) torch.Size([10, 6400]) .. GENERATED FROM PYTHON SOURCE LINES 488-490 TensorDictModules ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 490-499 .. code-block:: Python from tensordict.nn import TensorDictModule tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10]) module = nn.Linear(3, 4) td_module = TensorDictModule(module, in_keys=["key 1"], out_keys=["key 2"]) td_module(tensordict) print(tensordict) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ key 1: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False), key 2: Tensor(shape=torch.Size([10, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 500-502 Sequences of Modules ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 502-517 .. 
code-block:: Python from tensordict.nn import TensorDictSequential backbone_module = nn.Linear(5, 3) backbone = TensorDictModule( backbone_module, in_keys=["observation"], out_keys=["hidden"] ) actor_module = nn.Linear(3, 4) actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"]) value_module = MLP(out_features=1, num_cells=[4, 5]) value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"]) sequence = TensorDictSequential(backbone, actor, value) print(sequence) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDictSequential( module=ModuleList( (0): TensorDictModule( module=Linear(in_features=5, out_features=3, bias=True), device=cpu, in_keys=['observation'], out_keys=['hidden']) (1): TensorDictModule( module=Linear(in_features=3, out_features=4, bias=True), device=cpu, in_keys=['hidden'], out_keys=['action']) (2): TensorDictModule( module=MLP( (0): LazyLinear(in_features=0, out_features=4, bias=True) (1): Tanh() (2): Linear(in_features=4, out_features=5, bias=True) (3): Tanh() (4): Linear(in_features=5, out_features=1, bias=True) ), device=cpu, in_keys=['hidden', 'action'], out_keys=['value']) ), device=cpu, in_keys=['observation'], out_keys=['hidden', 'action', 'value']) .. GENERATED FROM PYTHON SOURCE LINES 518-521 .. code-block:: Python print(sequence.in_keys, sequence.out_keys) .. rst-class:: sphx-glr-script-out .. code-block:: none ['observation'] ['hidden', 'action', 'value'] .. GENERATED FROM PYTHON SOURCE LINES 522-531 .. code-block:: Python tensordict = TensorDict( {"observation": torch.randn(3, 5)}, [3], ) backbone(tensordict) actor(tensordict) value(tensordict) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False), value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 532-540 .. code-block:: Python tensordict = TensorDict( {"observation": torch.randn(3, 5)}, [3], ) sequence(tensordict) print(tensordict) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False), value: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 541-543 Functional Programming (Ensembling / Meta-RL) ---------------------------------------------- .. GENERATED FROM PYTHON SOURCE LINES 543-549 .. code-block:: Python from tensordict import TensorDict params = TensorDict.from_module(sequence) print("extracted params", params) .. rst-class:: sphx-glr-script-out .. 
code-block:: none extracted params TensorDict( fields={ module: TensorDict( fields={ 0: TensorDict( fields={ module: TensorDict( fields={ bias: Parameter(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), 1: TensorDict( fields={ module: TensorDict( fields={ bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), 2: TensorDict( fields={ module: TensorDict( fields={ 0: TensorDict( fields={ bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([4, 7]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), 2: TensorDict( fields={ bias: Parameter(shape=torch.Size([5]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False), 4: TensorDict( fields={ bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False), weight: Parameter(shape=torch.Size([1, 5]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 550-551 functional call using tensordict: .. 
GENERATED FROM PYTHON SOURCE LINES 551-555 .. code-block:: Python with params.to_module(sequence): sequence(tensordict) .. GENERATED FROM PYTHON SOURCE LINES 556-557 Using vectorized map for model ensembling .. GENERATED FROM PYTHON SOURCE LINES 557-570 .. code-block:: Python from torch import vmap params_expand = params.expand(4) def exec_sequence(params, data): with params.to_module(sequence): return sequence(data) tensordict_exp = vmap(exec_sequence, (0, None))(params_expand, tensordict) print(tensordict_exp) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([4, 3, 4]), device=cpu, dtype=torch.float32, is_shared=False), hidden: Tensor(shape=torch.Size([4, 3, 3]), device=cpu, dtype=torch.float32, is_shared=False), observation: Tensor(shape=torch.Size([4, 3, 5]), device=cpu, dtype=torch.float32, is_shared=False), value: Tensor(shape=torch.Size([4, 3, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([4, 3]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 571-573 Specialized Classes ------------------------------ .. GENERATED FROM PYTHON SOURCE LINES 573-586 .. code-block:: Python torch.manual_seed(0) from torchrl.data import BoundedTensorSpec from torchrl.modules import SafeModule spec = BoundedTensorSpec(-torch.ones(3), torch.ones(3)) base_module = nn.Linear(5, 3) module = SafeModule( module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True ) tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) module(tensordict)["action"] .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([-0.0137, 0.1524, -0.0641], grad_fn=) .. GENERATED FROM PYTHON SOURCE LINES 587-591 .. code-block:: Python tensordict = TensorDict({"obs": torch.randn(5) * 100}, batch_size=[]) module(tensordict)["action"] # safe=True projects the result within the set .. rst-class:: sphx-glr-script-out .. code-block:: none tensor([-1., 1., -1.], grad_fn=) .. 
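The effect of ``safe=True`` in the example above can be pictured as an
element-wise clamp onto the box defined by the spec. The snippet below is a
minimal, dependency-free sketch of that idea. ``project_to_box`` is a
hypothetical helper written only for illustration; it is not part of TorchRL,
whose own projection logic lives in the spec classes and may differ in detail.

.. code-block:: Python


    def project_to_box(values, low, high):
        """Clamp each value into [low, high], element by element."""
        return [min(max(v, lo), hi) for v, lo, hi in zip(values, low, high)]


    # Bounds mirroring the spec above: every component lies in [-1, 1].
    low, high = [-1.0] * 3, [1.0] * 3

    # Values already inside the box are returned unchanged.
    in_bounds = project_to_box([-0.5, 0.25, 0.0], low, high)

    # Values outside the box are moved to the nearest bound.
    out_of_bounds = project_to_box([-37.2, 12.9, -4.1], low, high)

    print(in_bounds)
    print(out_of_bounds)

On tensors, ``torch.clamp`` performs the same element-wise operation.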
GENERATED FROM PYTHON SOURCE LINES 592-605 .. code-block:: Python from torchrl.modules import Actor base_module = nn.Linear(5, 3) actor = Actor(base_module, in_keys=["obs"]) tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[]) actor(tensordict) # action is the default value from tensordict.nn import ( ProbabilisticTensorDictModule, ProbabilisticTensorDictSequential, ) .. GENERATED FROM PYTHON SOURCE LINES 606-627 .. code-block:: Python # Probabilistic modules from torchrl.modules import NormalParamExtractor, TanhNormal td = TensorDict({"input": torch.randn(3, 5)}, [3]) net = nn.Sequential( nn.Linear(5, 4), NormalParamExtractor() ) # splits the output in loc and scale module = TensorDictModule(net, in_keys=["input"], out_keys=["loc", "scale"]) td_module = ProbabilisticTensorDictSequential( module, ProbabilisticTensorDictModule( in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, return_log_prob=False, ), ) td_module(td) print(td) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False), input: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False), loc: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False), scale: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([3]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 628-643 .. code-block:: Python # returning the log-probability td = TensorDict({"input": torch.randn(3, 5)}, [3]) td_module = ProbabilisticTensorDictSequential( module, ProbabilisticTensorDictModule( in_keys=["loc", "scale"], out_keys=["action"], distribution_class=TanhNormal, return_log_prob=True, ), ) td_module(td) print(td) .. rst-class:: sphx-glr-script-out .. 
code-block:: none

    TensorDict(
        fields={
            action: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
            input: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
            loc: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False),
            sample_log_prob: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
            scale: Tensor(shape=torch.Size([3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3]),
        device=None,
        is_shared=False)


.. GENERATED FROM PYTHON SOURCE LINES 644-663

.. code-block:: Python


    # Sampling vs mode / mean
    from torchrl.envs.utils import ExplorationType, set_exploration_type

    td = TensorDict({"input": torch.randn(3, 5)}, [3])

    torch.manual_seed(0)
    with set_exploration_type(ExplorationType.RANDOM):
        td_module(td)
        print("random:", td["action"])

    with set_exploration_type(ExplorationType.MODE):
        td_module(td)
        print("mode:", td["action"])

    # Note: this block also runs under ExplorationType.MODE, so the values
    # printed under the "mean" label coincide with the mode printed above.
    with set_exploration_type(ExplorationType.MODE):
        td_module(td)
        print("mean:", td["action"])


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    random: tensor([[ 0.8728, -0.1334],
            [-0.9833,  0.3494],
            [-0.6887, -0.6402]], grad_fn=<_SafeTanhBackward>)
    mode: tensor([[-0.1132,  0.1762],
            [-0.3430, -0.2668],
            [ 0.2918,  0.6239]], grad_fn=<_SafeTanhBackward>)
    mean: tensor([[-0.1132,  0.1762],
            [-0.3430, -0.2668],
            [ 0.2918,  0.6239]], grad_fn=<_SafeTanhBackward>)


.. GENERATED FROM PYTHON SOURCE LINES 664-666

Using Environments and Modules
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 666-694

..
code-block:: Python from torchrl.envs.utils import step_mdp env = GymEnv("Pendulum-v1") action_spec = env.action_spec actor_module = nn.Linear(3, 1) actor = SafeModule( actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"] ) torch.manual_seed(0) env.set_seed(0) max_steps = 100 tensordict = env.reset() tensordicts = TensorDict({}, [max_steps]) for i in range(max_steps): actor(tensordict) tensordicts[i] = env.step(tensordict) if tensordict["done"].any(): break tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs tensordicts_prealloc = tensordicts.clone() print("total steps:", i) print(tensordicts) .. rst-class:: sphx-glr-script-out .. code-block:: none total steps: 99 TensorDict( fields={ action: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([100]), device=None, is_shared=False), observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([100]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 695-713 .. 
code-block:: Python # equivalent torch.manual_seed(0) env.set_seed(0) max_steps = 100 tensordict = env.reset() tensordicts = [] for _ in range(max_steps): actor(tensordict) tensordicts.append(env.step(tensordict)) if tensordict["done"].any(): break tensordict = step_mdp(tensordict) # roughly equivalent to obs = next_obs tensordicts_stack = torch.stack(tensordicts, 0) print("total steps:", i) print(tensordicts_stack) .. rst-class:: sphx-glr-script-out .. code-block:: none total steps: 99 TensorDict( fields={ action: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([100]), device=None, is_shared=False), observation: Tensor(shape=torch.Size([100, 3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([100, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([100]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 714-717 .. code-block:: Python (tensordicts_stack == tensordicts_prealloc).all() .. rst-class:: sphx-glr-script-out .. code-block:: none True .. GENERATED FROM PYTHON SOURCE LINES 718-729 .. 
code-block:: Python torch.manual_seed(0) env.set_seed(0) tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps) tensordict_rollout (tensordict_rollout == tensordicts_prealloc).all() from tensordict.nn import TensorDictModule .. GENERATED FROM PYTHON SOURCE LINES 730-732 Collectors ^^^^^^^^^^ .. GENERATED FROM PYTHON SOURCE LINES 732-738 .. code-block:: Python from torchrl.collectors import MultiaSyncDataCollector, MultiSyncDataCollector from torchrl.envs import EnvCreator, SerialEnv from torchrl.envs.libs.gym import GymEnv .. GENERATED FROM PYTHON SOURCE LINES 739-741 EnvCreator makes sure that we can send a lambda function from process to process We use a SerialEnv for simplicity, but for larger jobs a ParallelEnv would be better suited. .. GENERATED FROM PYTHON SOURCE LINES 741-751 .. code-block:: Python parallel_env = SerialEnv( 3, EnvCreator(lambda: GymEnv("Pendulum-v1")), ) create_env_fn = [parallel_env, parallel_env] actor_module = nn.Linear(3, 1) actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"]) .. GENERATED FROM PYTHON SOURCE LINES 752-753 Sync data collector .. GENERATED FROM PYTHON SOURCE LINES 753-765 .. code-block:: Python devices = ["cpu", "cpu"] collector = MultiSyncDataCollector( create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv policy=actor, total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) device=devices, ) .. GENERATED FROM PYTHON SOURCE LINES 766-775 .. code-block:: Python for i, d in enumerate(collector): if i == 0: print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) collector.shutdown() del collector .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ traj_ids: Tensor(shape=torch.Size([2, 3, 10]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([2, 3, 10]), device=cpu, is_shared=False), done: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([2, 3, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2, 3, 10]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([2, 3, 10, 3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([2, 3, 10, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([2, 3, 10]), device=cpu, is_shared=False) 3 .. GENERATED FROM PYTHON SOURCE LINES 776-797 .. 
code-block:: Python # async data collector: keeps working while you update your model collector = MultiaSyncDataCollector( create_env_fn=create_env_fn, # either a list of functions or a ParallelEnv policy=actor, total_frames=240, max_frames_per_traj=-1, # envs are terminating, we don't need to stop them early frames_per_batch=60, # we want 60 frames at a time (we have 3 envs per sub-collector) device=devices, ) for i, d in enumerate(collector): if i == 0: print(d) # trajectories are split automatically in [6 workers x 10 steps] collector.update_policy_weights_() # make sure that our policies have the latest weights if working on multiple devices print(i) collector.shutdown() del collector del create_env_fn del parallel_env .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False), collector: TensorDict( fields={ traj_ids: Tensor(shape=torch.Size([3, 20]), device=cpu, dtype=torch.int64, is_shared=False)}, batch_size=torch.Size([3, 20]), device=cpu, is_shared=False), done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3, 20]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([3, 20, 3]), device=cpu, dtype=torch.float32, is_shared=False), terminated: Tensor(shape=torch.Size([3, 20, 1]), device=cpu, dtype=torch.bool, is_shared=False), truncated: Tensor(shape=torch.Size([3, 20, 
1]), device=cpu, dtype=torch.bool, is_shared=False)}, batch_size=torch.Size([3, 20]), device=cpu, is_shared=False) 3 .. GENERATED FROM PYTHON SOURCE LINES 798-800 Objectives ^^^^^^^^^^ .. GENERATED FROM PYTHON SOURCE LINES 800-822 .. code-block:: Python # TorchRL delivers meta-RL compatible loss functions # Disclaimer: This APi may change in the future from torchrl.objectives import DDPGLoss actor_module = nn.Linear(3, 1) actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"]) class ConcatModule(nn.Linear): def forward(self, obs, action): return super().forward(torch.cat([obs, action], -1)) value_module = ConcatModule(4, 1) value = TensorDictModule( value_module, in_keys=["observation", "action"], out_keys=["state_action_value"] ) loss_fn = DDPGLoss(actor, value) loss_fn.make_value_estimator(loss_fn.default_value_estimator, gamma=0.99) .. GENERATED FROM PYTHON SOURCE LINES 823-839 .. code-block:: Python tensordict = TensorDict( { "observation": torch.randn(10, 3), "next": { "observation": torch.randn(10, 3), "reward": torch.randn(10, 1), "done": torch.zeros(10, 1, dtype=torch.bool), }, "action": torch.randn(10, 1), }, batch_size=[10], device="cpu", ) loss_td = loss_fn(tensordict) .. GENERATED FROM PYTHON SOURCE LINES 840-843 .. code-block:: Python print(loss_td) .. rst-class:: sphx-glr-script-out .. 
code-block:: none TensorDict( fields={ loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), pred_value: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False), pred_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), target_value: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False), target_value_max: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), td_error: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 844-847 .. code-block:: Python print(tensordict) .. rst-class:: sphx-glr-script-out .. code-block:: none TensorDict( fields={ action: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False), next: TensorDict( fields={ done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False), observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False), reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=cpu, is_shared=False), observation: Tensor(shape=torch.Size([10, 3]), device=cpu, dtype=torch.float32, is_shared=False), td_error: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([10]), device=cpu, is_shared=False) .. GENERATED FROM PYTHON SOURCE LINES 848-867 State of the Library ^^^^^^^^^^^^^^^^^^^^ TorchRL is currently an **alpha-release**: there may be bugs and there is no guarantee about BC-breaking changes. We should be able to move to a beta-release by the end of the year. 
Our roadmap to get there comprises:

- Distributed solutions
- Offline RL
- Greater support for meta-RL
- Multi-task and hierarchical RL

Contributing
^^^^^^^^^^^^

We are actively looking for contributors and early users. If you're working in
RL (or just curious), try it!

Give us feedback: what will make TorchRL successful is how well it covers
researchers' needs. To do that, we need their input! Since the library is
nascent, it is a great time for you to shape it the way you want!

.. GENERATED FROM PYTHON SOURCE LINES 869-873

Installing the Library
^^^^^^^^^^^^^^^^^^^^^^

The library is on PyPI: *pip install torchrl*

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (3 minutes 43.926 seconds)

**Estimated memory usage:** 324 MB

.. _sphx_glr_download_tutorials_torchrl_demo.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torchrl_demo.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torchrl_demo.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torchrl_demo.zip `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_