.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "tutorials/export.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_tutorials_export.py>` to download the full
        example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_tutorials_export.py:

Exporting tensordict modules
============================

**Author**: Vincent Moens

Prerequisites
~~~~~~~~~~~~~

Reading the TensorDictModule tutorial first is recommended to fully benefit from this
tutorial.

Once a module has been written using ``tensordict.nn``, it is often useful to isolate its
computational graph and export that graph, whether to execute the model on hardware (e.g.,
robots, drones, edge devices) or to eliminate the dependency on tensordict altogether.

PyTorch provides multiple methods for exporting modules, including ``onnx`` and
``torch.export``, both of which are compatible with ``tensordict``.

In this short tutorial, we will see how ``torch.export`` can be used to isolate the
computational graph of a model. ``torch.onnx`` support follows the same logic.

Key learnings
~~~~~~~~~~~~~

- Executing a ``tensordict.nn`` module without :class:`~tensordict.TensorDict` inputs;
- Selecting the output(s) of a model;
- Exporting such a model using ``torch.export``;
- Saving the model to a file;
- Isolating the PyTorch model.

.. GENERATED FROM PYTHON SOURCE LINES 35-45

.. code-block:: Python

    import time

    import torch
    from tensordict.nn import (
        NormalParamExtractor,
        TensorDictModule as Mod,
        TensorDictSequential as Seq,
    )
    from torch import nn

.. GENERATED FROM PYTHON SOURCE LINES 46-58

Designing the model
-------------------

Let us build a simple neural network using ``tensordict.nn``. The network will consist of:

- A linear layer mapping the input to a hidden representation;
- A ReLU activation;
- A final linear layer producing the output.
We will also include a :class:`tensordict.nn.NormalParamExtractor` to demonstrate how to
extract multiple outputs from a single tensor.

.. GENERATED FROM PYTHON SOURCE LINES 58-67

.. code-block:: Python

    model = Seq(
        # 1. A small network for embedding
        Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
        Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
        Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
        # 2. Extracting params (splits into loc and scale)
        Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
    )

.. GENERATED FROM PYTHON SOURCE LINES 68-70

Let us run this model and see what the output looks like:

.. GENERATED FROM PYTHON SOURCE LINES 70-74

.. code-block:: Python

    x = torch.randn(1, 3)
    print(model(x=x))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    (tensor([[0.3655, 0.2505, 0.1298, 0.0000]], grad_fn=), tensor([[ 0.1071, -0.0793, -0.2612,  0.0327]], grad_fn=), tensor([[ 0.1071, -0.0793]], grad_fn=), tensor([[0.8440, 1.0207]], grad_fn=))

.. GENERATED FROM PYTHON SOURCE LINES 75-90

As expected, running the model with a tensor input returns as many tensors as the module's
output keys! For large models, this can be quite annoying and wasteful. Later, we will see
how we can limit the number of outputs of the model to deal with this issue.

Using ``torch.export`` with a ``TensorDictModule``
--------------------------------------------------

Now that we have successfully built our model, we would like to extract its computational
graph in a single object that is independent of ``tensordict``.

``torch.export`` is a PyTorch module dedicated to isolating the graph of a module and
representing it in a standardized way. Its main entry point is
:func:`~torch.export.export`, which returns an ``ExportedProgram`` object.
In turn, this object has several attributes of interest that we will explore below: a
``graph_module``, which represents the FX graph captured by ``export``; a
``graph_signature``, which describes the inputs, outputs, etc., of the graph; and finally a
``module()`` method, which returns a callable that can be used in place of the original
module.

Although our module accepts both args and kwargs, we will focus on its usage with kwargs,
as this is clearer.

.. GENERATED FROM PYTHON SOURCE LINES 90-95

.. code-block:: Python

    from torch.export import export

    model_export = export(model, args=(), kwargs={"x": x})

.. GENERATED FROM PYTHON SOURCE LINES 96-98

Let us look at the module:

.. GENERATED FROM PYTHON SOURCE LINES 98-100

.. code-block:: Python

    print("module:", model_export.module())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    module: GraphModule(
      (module): Module(
        (0): Module(
          (module): Module()
        )
        (2): Module(
          (module): Module()
        )
      )
      (_guards_fn): GuardsFn()
    )

    def forward(self, x):
        x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
        module_0_module_weight = getattr(self.module, "0").module.weight
        module_0_module_bias = getattr(self.module, "0").module.bias
        module_2_module_weight = getattr(self.module, "2").module.weight
        module_2_module_bias = getattr(self.module, "2").module.bias
        _guards_fn = self._guards_fn(x);  _guards_fn = None
        linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
        relu = torch.ops.aten.relu.default(linear);  linear = None
        linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  module_2_module_weight = module_2_module_bias = None
        chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
        getitem = chunk[0]
        getitem_1 = chunk[1];  chunk = None
        add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus = torch.ops.aten.softplus.default(add);  add = None
        add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
        clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
        return pytree.tree_unflatten((relu, linear_1, getitem, clamp_min), self._out_spec)

    # To see more debug info, please use `graph_module.print_readable()`

.. GENERATED FROM PYTHON SOURCE LINES 101-103

This module can be run exactly like our original module (with lower overhead):

.. GENERATED FROM PYTHON SOURCE LINES 103-114

.. code-block:: Python

    t0 = time.time()
    model(x=x)
    print(f"Time for TDModule: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")

    exported = model_export.module()

    # Exported version
    t0 = time.time()
    exported(x=x)
    print(f"Time for exported module: {(time.time() - t0) * 1e6: 4.2f} micro-seconds")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Time for TDModule:  504.26 micro-seconds
    Time for exported module:  357.15 micro-seconds

.. GENERATED FROM PYTHON SOURCE LINES 115-116

and the FX graph:

.. GENERATED FROM PYTHON SOURCE LINES 116-118

.. code-block:: Python

    print("fx graph:", model_export.graph_module.print_readable())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    class GraphModule(torch.nn.Module):
        def forward(self, p_module_0_module_weight: "f32[4, 3]", p_module_0_module_bias: "f32[4]", p_module_2_module_weight: "f32[4, 4]", p_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_module_0_module_weight, p_module_0_module_bias);  x = p_module_0_module_weight = p_module_0_module_bias = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_module_2_module_weight, p_module_2_module_bias);  p_module_2_module_weight = p_module_2_module_bias = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
            chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
            getitem: "f32[1, 2]" = chunk[0]
            getitem_1: "f32[1, 2]" = chunk[1];  chunk = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
            add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
            softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
            add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
            clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
            return (relu, linear_1, getitem, clamp_min)

    fx graph: class GraphModule(torch.nn.Module):
        def forward(self, p_module_0_module_weight: "f32[4, 3]", p_module_0_module_bias: "f32[4]", p_module_2_module_weight: "f32[4, 4]", p_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_module_0_module_weight, p_module_0_module_bias);  x = p_module_0_module_weight = p_module_0_module_bias = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear);  linear = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_module_2_module_weight, p_module_2_module_bias);  p_module_2_module_weight = p_module_2_module_bias = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
            chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
            getitem: "f32[1, 2]" = chunk[0]
            getitem_1: "f32[1, 2]" = chunk[1];  chunk = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
            add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
            softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add);  add = None
            add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None

            # File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
            clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = None
            return (relu, linear_1, getitem, clamp_min)

.. GENERATED FROM PYTHON SOURCE LINES 119-127

Working with nested keys
~~~~~~~~~~~~~~~~~~~~~~~~

Nested keys are a core feature of the tensordict library, and being able to export modules
that read and write nested entries is therefore an important feature to support. Because
keyword arguments must be regular strings, it is not possible for
:class:`~tensordict.nn.dispatch` to work with them directly. Instead, ``dispatch`` will
unpack nested keys joined with a regular underscore (``"_"``), as the following example
shows.

.. GENERATED FROM PYTHON SOURCE LINES 127-137

.. code-block:: Python

    model_nested = Seq(
        Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
        Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
    ).select_out_keys(("some", "output"))

    model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
    print("exported module with nested input:", model_nested_export.module())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    exported module with nested input: GraphModule(
      (_guards_fn): GuardsFn()
    )

    def forward(self, some_key):
        some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
        _guards_fn = self._guards_fn(some_key);  _guards_fn = None
        add = torch.ops.aten.add.Tensor(some_key, 1);  some_key = None
        sub = torch.ops.aten.sub.Tensor(add, 1);  add = None
        return pytree.tree_unflatten((sub,), self._out_spec)

    # To see more debug info, please use `graph_module.print_readable()`

.. GENERATED FROM PYTHON SOURCE LINES 138-174

Note that the callable returned by ``module()`` is a pure python callable that can in turn
be compiled using :func:`~torch.compile`.
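To make that last point concrete, here is a minimal, self-contained sketch showing an exported callable being handed to :func:`torch.compile`. The two-layer network, its sizes, and the input shape below are illustrative placeholders, not this tutorial's model:

.. code-block:: Python

    import torch
    from torch import nn
    from torch.export import export

    # A small stand-in network (hypothetical; any nn.Module works the same way).
    net = nn.Sequential(nn.Linear(3, 4), nn.ReLU())
    x = torch.randn(1, 3)

    # Export, then retrieve the plain-python callable.
    fn = export(net, args=(x,)).module()

    # The exported callable reproduces the eager module's output...
    assert torch.allclose(fn(x), net(x))

    # ...and, being a regular callable, it can be wrapped by torch.compile
    # (compilation itself happens lazily, on the first call).
    compiled_fn = torch.compile(fn)

This composes nicely: the exported graph removes the ``tensordict`` machinery, and ``torch.compile`` can then optimize what remains.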
Saving the exported module
~~~~~~~~~~~~~~~~~~~~~~~~~~

``torch.export`` has its own serialization protocol, :func:`~torch.export.save` and
:func:`~torch.export.load`. Conventionally, the ``".pt2"`` extension is used:

  >>> torch.export.save(model_export, "model.pt2")

Selecting the outputs
---------------------

Recall that ``tensordict.nn`` keeps every intermediate value in the output, unless the
user specifically asks only for a specific value. During training, this can be very
useful: one can easily log intermediate values of the graph, or use them for other
purposes (e.g., reconstructing a distribution based on its saved parameters, rather than
saving the :class:`~torch.distributions.Distribution` object itself). One could also argue
that, during training, the memory impact of registering intermediate values is negligible,
since they are part of the computational graph used by ``torch.autograd`` to compute the
parameter gradients.

During inference, though, we are most likely only interested in specific outputs of the
model. Because we want to extract the model for usages that are independent of the
``tensordict`` library, it makes sense to isolate only the output we desire. To do this,
we have several options:

1. Build the :class:`~tensordict.nn.TensorDictSequential` with the ``selected_out_keys``
   keyword argument, which will induce the selection of the desired entries during calls
   to the module;
2. Use the :meth:`~tensordict.nn.TensorDictModule.select_out_keys` method, which will
   modify the ``out_keys`` attribute in place (this can be reverted through
   :meth:`~tensordict.nn.TensorDictModule.reset_out_keys`);
3. Wrap the existing instance in a :class:`~tensordict.nn.TensorDictSequential` that will
   filter out the unwanted keys:

     >>> module_filtered = Seq(module, selected_out_keys=["loc"])

Let us test the model after selecting its output keys. When an ``x`` input is provided, we
expect our model to output a single tensor corresponding to the ``"loc"`` output:

.. GENERATED FROM PYTHON SOURCE LINES 174-178

.. code-block:: Python

    model.select_out_keys("loc")
    print(model(x=x))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ 0.1071, -0.0793]], grad_fn=)

.. GENERATED FROM PYTHON SOURCE LINES 179-181

We see that the output is now a single tensor. We can create a new exported graph from
this module; its computational graph should be simplified:

.. GENERATED FROM PYTHON SOURCE LINES 181-185

.. code-block:: Python

    model_export = export(model, args=(), kwargs={"x": x})
    print("module:", model_export.module())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    module: GraphModule(
      (module): Module(
        (0): Module(
          (module): Module()
        )
        (2): Module(
          (module): Module()
        )
      )
      (_guards_fn): GuardsFn()
    )

    def forward(self, x):
        x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
        module_0_module_weight = getattr(self.module, "0").module.weight
        module_0_module_bias = getattr(self.module, "0").module.bias
        module_2_module_weight = getattr(self.module, "2").module.weight
        module_2_module_bias = getattr(self.module, "2").module.bias
        _guards_fn = self._guards_fn(x);  _guards_fn = None
        linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias);  x = module_0_module_weight = module_0_module_bias = None
        relu = torch.ops.aten.relu.default(linear);  linear = None
        linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias);  relu = module_2_module_weight = module_2_module_bias = None
        chunk = torch.ops.aten.chunk.default(linear_1, 2, -1);  linear_1 = None
        getitem = chunk[0]
        getitem_1 = chunk[1];  chunk = None
        add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335);  getitem_1 = None
        softplus = torch.ops.aten.softplus.default(add);  add = None
        add_1 = torch.ops.aten.add.Tensor(softplus, 0.01);  softplus = None
        clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001);  add_1 = clamp_min = None
        return pytree.tree_unflatten((getitem,), self._out_spec)

    # To see more debug info, please use `graph_module.print_readable()`

.. GENERATED FROM PYTHON SOURCE LINES 186-198

This is all you need to know to use ``torch.export``. Please refer to the official
documentation for more info.

Next steps and further reading
------------------------------

- Check the ``torch.export`` tutorial in the PyTorch documentation;
- ONNX support: check the ONNX tutorials to learn more about this feature. Exporting to
  ONNX is very similar to the ``torch.export`` approach explained here;
- For deployment of PyTorch code on servers without a python environment, check the
  AOTInductor documentation.

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 0.405 seconds)

.. _sphx_glr_download_tutorials_export.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: export.ipynb <export.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: export.py <export.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: export.zip <export.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_