.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/functional.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_functional.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_functional.py:

Functional Programming with TensorDict
======================================

**Author**: `Vincent Moens `_

In this tutorial you will learn how to use :class:`~.TensorDict` for
functional-style programming with :class:`~torch.nn.Module`, including
parameter swapping, model ensembling with :func:`~torch.vmap`, and
functional calls with :func:`~torch.func.functional_call`.

.. GENERATED FROM PYTHON SOURCE LINES 13-18

TensorDict as a parameter container
-----------------------------------

:meth:`~.TensorDict.from_module` extracts the parameters of a module into a
nested :class:`~.TensorDict` whose structure mirrors the module hierarchy.

.. GENERATED FROM PYTHON SOURCE LINES 18-27

.. code-block:: Python

    import torch
    import torch.nn as nn

    from tensordict import TensorDict

    module = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))
    params = TensorDict.from_module(module)
    print(params)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    TensorDict(
        fields={
            0: TensorDict(
                fields={
                    bias: Parameter(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
                    weight: Parameter(shape=torch.Size([4, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False),
            2: TensorDict(
                fields={
                    bias: Parameter(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                    weight: Parameter(shape=torch.Size([1, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
                batch_size=torch.Size([]),
                device=None,
                is_shared=False)},
        batch_size=torch.Size([]),
        device=None,
        is_shared=False)

..
.. GENERATED FROM PYTHON SOURCE LINES 33-36

The resulting :class:`~.TensorDict` holds the same
:class:`~torch.nn.Parameter` objects as the module. We can manipulate them as
a batch -- for example, zeroing all parameters at once:

.. GENERATED FROM PYTHON SOURCE LINES 36-40

.. code-block:: Python

    params_zero = params.detach().clone().zero_()
    print("All zeros:", (params_zero == 0).all())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    All zeros: True

.. GENERATED FROM PYTHON SOURCE LINES 41-47

Swapping parameters with a context manager
------------------------------------------

:meth:`~.TensorDict.to_module` temporarily replaces the parameters of a
module within a context manager. The original parameters are restored on
exit.

.. GENERATED FROM PYTHON SOURCE LINES 47-60

.. code-block:: Python

    x = torch.randn(5, 3)

    with params_zero.to_module(module):
        y_zero = module(x)
    print("Output with zeroed params:", y_zero)
    assert (y_zero == 0).all()

    y_original = module(x)
    print("Output with original params:", y_original)
    assert not (y_original == 0).all()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Output with zeroed params: tensor([[0.],
            [0.],
            [0.],
            [0.],
            [0.]])
    Output with original params: tensor([[-0.0777],
            [ 0.1591],
            [ 0.4036],
            [-0.1975],
            [-0.0056]], grad_fn=<AddmmBackward0>)

.. GENERATED FROM PYTHON SOURCE LINES 61-67

Model ensembling with ``torch.vmap``
------------------------------------

Because :class:`~.TensorDict` supports batching and stacking, we can stack
multiple parameter configurations and use :func:`~torch.vmap` to run the
model across all of them in a single vectorized call.

.. GENERATED FROM PYTHON SOURCE LINES 67-85

..
.. code-block:: Python

    params_ones = params.detach().clone().apply_(lambda t: t.fill_(1.0))
    params_stack = torch.stack([params_zero, params_ones, params])
    print("Stacked params batch_size:", params_stack.batch_size)


    def call(x, td):
        with td.to_module(module):
            return module(x)


    x = torch.randn(3, 5, 3)
    y = torch.vmap(call)(x, params_stack)
    print("Output shape:", y.shape)
    assert (y[0] == 0).all()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Stacked params batch_size: torch.Size([3])
    Output shape: torch.Size([3, 5, 1])

.. GENERATED FROM PYTHON SOURCE LINES 86-93

Functional calls with ``torch.func``
------------------------------------

:func:`~torch.func.functional_call` works with the state dict extracted by
:meth:`~.TensorDict.from_module`. Because ``from_module`` returns a
:class:`~.TensorDict` with the same structure as a state dict, we can
flatten its nested keys into dotted names and pass the result as a regular
dict:

.. GENERATED FROM PYTHON SOURCE LINES 93-102

.. code-block:: Python

    from torch.func import functional_call

    flat_params = params.flatten_keys(".")
    state_dict = dict(flat_params.items())

    x = torch.randn(5, 3)
    y = functional_call(module, state_dict, x)
    print("functional_call output:", y.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    functional_call output: torch.Size([5, 1])

.. GENERATED FROM PYTHON SOURCE LINES 103-108

The combination of :meth:`~.TensorDict.from_module`,
:meth:`~.TensorDict.to_module`, and :func:`~torch.vmap` makes it
straightforward to compute per-sample gradients, run model ensembles, or
implement meta-learning inner loops -- all without leaving the standard
PyTorch API.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 0.006 seconds)

.. _sphx_glr_download_tutorials_functional.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: functional.ipynb <functional.ipynb>`
    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: functional.py <functional.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: functional.zip <functional.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

 `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_