.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/multi_task.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_multi_task.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_multi_task.py:

Task-specific policy in multi-task environments
================================================

This tutorial details how multi-task policies and batched environments can be used.

.. GENERATED FROM PYTHON SOURCE LINES 7-10

At the end of this tutorial, you will be capable of writing policies that
can compute actions in diverse settings using a distinct set of weights.
You will also be able to execute diverse environments in parallel.

.. GENERATED FROM PYTHON SOURCE LINES 10-16

.. code-block:: Python

    from tensordict import LazyStackedTensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torch import nn

.. GENERATED FROM PYTHON SOURCE LINES 38-43

.. code-block:: Python

    from torchrl.envs import CatTensors, Compose, DoubleToFloat, ParallelEnv, TransformedEnv
    from torchrl.envs.libs.dm_control import DMControlEnv
    from torchrl.modules import MLP

.. GENERATED FROM PYTHON SOURCE LINES 44-46

We design two environments: one humanoid that must complete the stand task,
and another that must learn to walk.

.. GENERATED FROM PYTHON SOURCE LINES 46-74

.. code-block:: Python

    env1 = DMControlEnv("humanoid", "stand")
    env1_obs_keys = list(env1.observation_spec.keys())
    env1 = TransformedEnv(
        env1,
        Compose(
            CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
            CatTensors(env1_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_stand", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )
    env2 = DMControlEnv("humanoid", "walk")
    env2_obs_keys = list(env2.observation_spec.keys())
    env2 = TransformedEnv(
        env2,
        Compose(
            CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
            CatTensors(env2_obs_keys, "observation"),
            DoubleToFloat(
                in_keys=["observation_walk", "observation"],
                in_keys_inv=["action"],
            ),
        ),
    )

.. GENERATED FROM PYTHON SOURCE LINES 75-84

.. code-block:: Python

    tdreset1 = env1.reset()
    tdreset2 = env2.reset()

    # With LazyStackedTensorDict, stacking is done in a lazy manner: the
    # original tensordicts can still be recovered by indexing the main tensordict
    tdreset = LazyStackedTensorDict.lazy_stack([tdreset1, tdreset2], 0)
    assert tdreset[0] is tdreset1

.. GENERATED FROM PYTHON SOURCE LINES 85-88

.. code-block:: Python

    print(tdreset[0])

.. GENERATED FROM PYTHON SOURCE LINES 89-96

Policy
^^^^^^

We will design a policy where a backbone reads the "observation" key.
Then specific sub-components will read the "observation_stand" and
"observation_walk" keys of the stacked tensordicts, if they are present,
and pass them through the dedicated sub-network.

.. GENERATED FROM PYTHON SOURCE LINES 96-99

.. code-block:: Python

    action_dim = env1.action_spec.shape[-1]

.. GENERATED FROM PYTHON SOURCE LINES 100-118

.. code-block:: Python

    policy_common = TensorDictModule(
        nn.Linear(67, 64), in_keys=["observation"], out_keys=["hidden"]
    )
    policy_stand = TensorDictModule(
        MLP(67 + 64, action_dim, depth=2),
        in_keys=["observation_stand", "hidden"],
        out_keys=["action"],
    )
    policy_walk = TensorDictModule(
        MLP(67 + 64, action_dim, depth=2),
        in_keys=["observation_walk", "hidden"],
        out_keys=["action"],
    )
    seq = TensorDictSequential(
        policy_common, policy_stand, policy_walk, partial_tolerant=True
    )
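Before running the real environments through it, it is worth pausing on what
``partial_tolerant=True`` buys us. Below is a minimal sketch of the mechanism,
relying on the documented behaviour of ``TensorDictSequential`` with lazily
stacked inputs; the module and key names (``mod_a``, ``obs_a``, ...) are toy
placeholders, not part of the tutorial.

.. code-block:: Python

    import torch
    from tensordict import TensorDict, LazyStackedTensorDict
    from tensordict.nn import TensorDictModule, TensorDictSequential
    from torch import nn

    # Two modules whose input keys each exist in only one of the stacked
    # tensordicts.
    mod_a = TensorDictModule(nn.Identity(), in_keys=["obs_a"], out_keys=["out"])
    mod_b = TensorDictModule(nn.Identity(), in_keys=["obs_b"], out_keys=["out"])
    toy = TensorDictSequential(mod_a, mod_b, partial_tolerant=True)

    td = LazyStackedTensorDict.lazy_stack(
        [
            TensorDict({"obs_a": torch.randn(3)}, []),
            TensorDict({"obs_b": torch.randn(3)}, []),
        ],
        0,
    )
    toy(td)

    # Each module ran only on the sub-tensordict that carries its input key.
    # "out" is now shared by both elements, so it becomes visible on the stack.
    print(list(td.keys()))
    print(td[0]["out"], td[1]["out"])

The same dispatch happens below with the stand/walk observation keys, while
the common backbone runs on the whole stack at once.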
.. GENERATED FROM PYTHON SOURCE LINES 119-120

Let's check that our sequence outputs actions for a single env (stand).

.. GENERATED FROM PYTHON SOURCE LINES 120-123

.. code-block:: Python

    seq(env1.reset())

.. GENERATED FROM PYTHON SOURCE LINES 124-125

Let's check that our sequence outputs actions for a single env (walk).

.. GENERATED FROM PYTHON SOURCE LINES 125-128

.. code-block:: Python

    seq(env2.reset())

.. GENERATED FROM PYTHON SOURCE LINES 129-133

This also works with the stack: now the stand and walk keys have disappeared,
because they're not shared by all tensordicts. But the ``TensorDictSequential``
still performed the operations. Note that the backbone was executed in a
vectorized way (not in a loop), which is more efficient.

.. GENERATED FROM PYTHON SOURCE LINES 133-136

.. code-block:: Python

    seq(tdreset)

.. GENERATED FROM PYTHON SOURCE LINES 137-148

Executing diverse tasks in parallel
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

We can parallelize the operations if the common key-value pairs share the
same specs (in particular their shape and dtype must match: you can't do
the following if the observation shapes are different but are pointed to
by the same key).

If ``ParallelEnv`` receives a single env-making function, it will assume that
a single task has to be performed. If a list of functions is provided, then
it will assume that we are in a multi-task setting.

.. GENERATED FROM PYTHON SOURCE LINES 148-186

.. code-block:: Python

    def env1_maker():
        return TransformedEnv(
            DMControlEnv("humanoid", "stand"),
            Compose(
                CatTensors(env1_obs_keys, "observation_stand", del_keys=False),
                CatTensors(env1_obs_keys, "observation"),
                DoubleToFloat(
                    in_keys=["observation_stand", "observation"],
                    in_keys_inv=["action"],
                ),
            ),
        )


    def env2_maker():
        return TransformedEnv(
            DMControlEnv("humanoid", "walk"),
            Compose(
                CatTensors(env2_obs_keys, "observation_walk", del_keys=False),
                CatTensors(env2_obs_keys, "observation"),
                DoubleToFloat(
                    in_keys=["observation_walk", "observation"],
                    in_keys_inv=["action"],
                ),
            ),
        )


    env = ParallelEnv(2, [env1_maker, env2_maker])
    assert not env._single_task

    tdreset = env.reset()
    print(tdreset)
    print(tdreset[0])
    print(tdreset[1])  # should be different

.. GENERATED FROM PYTHON SOURCE LINES 187-188

Let's pass the output through our network.

.. GENERATED FROM PYTHON SOURCE LINES 188-200

.. code-block:: Python

    tdreset = seq(tdreset)
    print(tdreset)
    print(tdreset[0])
    print(tdreset[1])  # should be different but all have an "action" key

    env.step(tdreset)  # executes the steps in parallel, using the precomputed actions
    print(tdreset)
    print(tdreset[0])
    print(tdreset[1])  # the "next" entry (observation, reward, done) has now been written

.. GENERATED FROM PYTHON SOURCE LINES 201-203

Rollout
^^^^^^^

.. GENERATED FROM PYTHON SOURCE LINES 203-206

.. code-block:: Python

    td_rollout = env.rollout(100, policy=seq, return_contiguous=False)

.. GENERATED FROM PYTHON SOURCE LINES 207-210

.. code-block:: Python

    td_rollout[:, 0]  # tensordict of the first step: only the common keys are shown

.. GENERATED FROM PYTHON SOURCE LINES 211-216

.. code-block:: Python

    td_rollout[0]  # tensordict of the first env: the stand obs is present

    env.close()
    del env
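Because the rollout is kept as a lazy stack, per-task statistics can still be
extracted after the environment is closed. The following post-processing
sketch is illustrative; it assumes (as is the case here) that both workers
completed the same number of steps, so that indexing the stacked rewards
yields a dense tensor.

.. code-block:: Python

    # Rewards are logged under the "next" entry of the rollout. Both tasks
    # share the same reward spec, so indexing the stack yields a dense
    # [2, T, 1] tensor (T = 100 steps here).
    rewards = td_rollout["next", "reward"]
    returns = rewards.sum(dim=-2)  # undiscounted per-task return, shape [2, 1]
    print(returns)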
.. _sphx_glr_download_tutorials_multi_task.py:

.. only:: html

    .. container:: sphx-glr-footer sphx-glr-footer-example

        .. container:: sphx-glr-download sphx-glr-download-jupyter

            :download:`Download Jupyter notebook: multi_task.ipynb <multi_task.ipynb>`

        .. container:: sphx-glr-download sphx-glr-download-python

            :download:`Download Python source code: multi_task.py <multi_task.py>`

        .. container:: sphx-glr-download sphx-glr-download-zip

            :download:`Download zipped: multi_task.zip <multi_task.zip>`

.. only:: html

    .. rst-class:: sphx-glr-signature

        `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_