.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/tensordict_module_functional.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_tensordict_module_functional.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_tensordict_module_functional.py:


Functionalizing TensorDictModule
================================

In this tutorial you will learn how to use :class:`~.TensorDictModule` in
conjunction with functorch to create functionalized modules.

.. GENERATED FROM PYTHON SOURCE LINES 9-14

Before we take a look at the functional utilities in :mod:`tensordict.nn`, let us
reintroduce one of the example modules from the :class:`~.TensorDictModule`
tutorial. We'll create a simple module with two linear layers that share the input
and return separate outputs.

.. GENERATED FROM PYTHON SOURCE LINES 14-32

.. code-block:: Python

    import functorch
    import torch
    import torch.nn as nn
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule


    class MultiHeadLinear(nn.Module):
        def __init__(self, in_1, out_1, out_2):
            super().__init__()
            self.linear_1 = nn.Linear(in_1, out_1)
            self.linear_2 = nn.Linear(in_1, out_2)

        def forward(self, x):
            return self.linear_1(x), self.linear_2(x)

.. GENERATED FROM PYTHON SOURCE LINES 33-35

We can now create a :class:`~.TensorDictModule` that will read the input from a key
``"a"``, and write to the keys ``"output_1"`` and ``"output_2"``.

.. GENERATED FROM PYTHON SOURCE LINES 35-39

.. code-block:: Python

    splitlinear = TensorDictModule(
        MultiHeadLinear(3, 4, 10), in_keys=["a"], out_keys=["output_1", "output_2"]
    )

.. GENERATED FROM PYTHON SOURCE LINES 40-42

Ordinarily we would use this module by simply calling it on a :class:`~.TensorDict`
with the required input keys.

.. GENERATED FROM PYTHON SOURCE LINES 42-48

.. code-block:: Python

    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])
    splitlinear(tensordict)
    print(tensordict)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 49-51

However, we can also use :func:`functorch.make_functional_with_buffers` in order to
functionalize the module.

.. GENERATED FROM PYTHON SOURCE LINES 51-54

.. code-block:: Python

    func, params, buffers = functorch.make_functional_with_buffers(splitlinear)
    print(func(params, buffers, tensordict))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            output_1: Tensor(shape=torch.Size([5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            output_2: Tensor(shape=torch.Size([5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([5]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 55-58

This can be used with the :func:`torch.vmap` operator. For example, here we create 3
replicas of the params and buffers and execute a vectorized map over them for a
single batch of data:

.. GENERATED FROM PYTHON SOURCE LINES 58-63

.. code-block:: Python

    params_expand = [p.expand(3, *p.shape) for p in params]
    buffers_expand = [b.expand(3, *b.shape) for b in buffers]
    print(torch.vmap(func, (0, 0, None))(params_expand, buffers_expand, tensordict))

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    TensorDict(
        fields={
            a: Tensor(shape=torch.Size([3, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
            output_1: Tensor(shape=torch.Size([3, 5, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            output_2: Tensor(shape=torch.Size([3, 5, 10]), device=cpu, dtype=torch.float32, is_shared=False)},
        batch_size=torch.Size([3, 5]),
        device=None,
        is_shared=False)

.. GENERATED FROM PYTHON SOURCE LINES 64-67

We can also use the native :func:`make_functional <tensordict.nn.make_functional>`
function from :mod:`tensordict.nn`, which modifies the module to make it accept the
parameters as regular inputs:

.. GENERATED FROM PYTHON SOURCE LINES 67-79

.. code-block:: Python

    from tensordict.nn import make_functional

    tensordict = TensorDict({"a": torch.randn(5, 3)}, batch_size=[5])

    model = TensorDictModule(nn.Linear(3, 4), in_keys=["a"], out_keys=["output"])
    params = make_functional(model)

    # we stack two groups of parameters to show the vmap usage:
    params = torch.stack([params, params.apply(lambda x: torch.zeros_like(x))], 0)

    result_td = torch.vmap(model, (None, 0))(tensordict, params)
    print("the output tensordict shape is: ", result_td.shape)

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    the output tensordict shape is:  torch.Size([2, 5])
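Since the functionalized module accepts its parameters as a regular input, a single
parameter set can also be passed directly, without :func:`torch.vmap`. The short
sketch below is not part of the generated script: it simply indexes the stacked
``params`` to recover one parameter set and calls the module with it, following the
same call convention as the ``torch.vmap(model, (None, 0))(tensordict, params)``
line above. The variable names ``single_params`` and ``out_td`` are illustrative
only.

.. code-block:: Python

    # A minimal sketch: take the first of the two stacked parameter sets and
    # call the functionalized module with it directly (no vmap involved).
    single_params = params[0]
    out_td = model(tensordict, single_params)
    # expected: a TensorDict with batch_size [5] containing an "output" entry
    print(out_td)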
.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.006 seconds)


.. _sphx_glr_download_tutorials_tensordict_module_functional.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: tensordict_module_functional.ipynb <tensordict_module_functional.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: tensordict_module_functional.py <tensordict_module_functional.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: tensordict_module_functional.zip <tensordict_module_functional.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_