Exporting tensordict modules
Author: Vincent Moens
Prerequisites
Reading the TensorDictModule tutorial is recommended to fully benefit from this tutorial.
Once a module has been written using tensordict.nn, it is often useful to isolate the computational graph and export
that graph. The goal of this may be to execute the model on hardware (e.g., robots, drones, edge devices) or eliminate
the dependency on tensordict altogether.
PyTorch provides multiple methods for exporting modules, including onnx and torch.export, both of which are
compatible with tensordict.
In this short tutorial, we will see how one can use torch.export to isolate the computational graph of a model.
torch.onnx support follows the same logic.
Key learnings
- Executing a tensordict.nn module without TensorDict inputs;
- Selecting the output(s) of a model;
- Exporting such a model using torch.export;
- Saving the model to a file;
- Isolating the PyTorch model.
import time
import torch
from tensordict.nn import (
NormalParamExtractor,
TensorDictModule as Mod,
TensorDictSequential as Seq,
)
from torch import nn
Designing the model
Let us build a simple neural network using tensordict.nn. The network will consist of:
A linear layer mapping input to a hidden representation;
A ReLU activation;
A final linear layer producing the output.
We will also include a tensordict.nn.NormalParamExtractor to demonstrate how to extract
multiple outputs from a single tensor.
model = Seq(
# 1. A small network for embedding
Mod(nn.Linear(3, 4), in_keys=["x"], out_keys=["hidden"]),
Mod(nn.ReLU(), in_keys=["hidden"], out_keys=["hidden"]),
Mod(nn.Linear(4, 4), in_keys=["hidden"], out_keys=["latent"]),
# 2. Extracting params (splits into loc and scale)
Mod(NormalParamExtractor(), in_keys=["latent"], out_keys=["loc", "scale"]),
)
Let us run this model and see what the output looks like:
x = torch.randn(1, 3)
print(model(x=x))
(tensor([[0.3655, 0.2505, 0.1298, 0.0000]], grad_fn=<ReluBackward0>), tensor([[ 0.1071, -0.0793, -0.2612, 0.0327]], grad_fn=<AddmmBackward0>), tensor([[ 0.1071, -0.0793]], grad_fn=<SplitBackward0>), tensor([[0.8440, 1.0207]], grad_fn=<ClampMinBackward0>))
As expected, running the model with a tensor input returns as many tensors as the module has output keys! For large models, this can be quite annoying and wasteful. Later, we will see how to limit the number of outputs of the model to deal with this issue.
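For reference, the returned tensors follow the order of the module's out_keys; a minimal sketch of unpacking them, assuming the keys defined above:
# The tuple follows the sequential module's out_keys order,
# here ("hidden", "latent", "loc", "scale").
hidden, latent, loc, scale = model(x=x)
print(model.out_keys)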
Using torch.export with a TensorDictModule
Now that we have successfully built our model, we would like to extract its computational graph into a single object that
is independent of tensordict. torch.export is a PyTorch module dedicated to isolating the graph of a module and
representing it in a standardized way. Its main entry point is export(), which returns an ExportedProgram
object. In turn, this object has several attributes of interest that we will explore below: a graph_module,
which represents the FX graph captured by export, a graph_signature with inputs, outputs, etc., of the graph,
and finally a module() method that returns a callable that can be used in place of the original module.
Although our module accepts both args and kwargs, we will focus on its usage with kwargs as this is clearer.
from torch.export import export
model_export = export(model, args=(), kwargs={"x": x})
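Before looking at the module itself, we can peek at the inputs and outputs recorded by export; a quick, hedged sketch using the graph_signature attribute mentioned above:
# Print the signature of the captured graph (parameters, user inputs, outputs)
print("graph signature:", model_export.graph_signature)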
Let us look at the module:
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
(_guards_fn): GuardsFn()
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
_guards_fn = self._guards_fn(x); _guards_fn = None
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
return pytree.tree_unflatten((relu, linear_1, getitem, clamp_min), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
This module can be run exactly like our original module (with a lower overhead):
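The timing code itself is not shown here; a minimal sketch of how such a comparison could be made, re-using the time module imported earlier:
# Hedged sketch: compare a single call to the TensorDictModule and to the exported callable
exported_module = model_export.module()
t0 = time.time()
model(x=x)
print(f"Time for TDModule: {(time.time() - t0) * 1e6:.2f} micro-seconds")
t0 = time.time()
exported_module(x=x)
print(f"Time for exported module: {(time.time() - t0) * 1e6:.2f} micro-seconds")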
Time for TDModule: 504.26 micro-seconds
Time for exported module: 357.15 micro-seconds
and the FX graph:
print("fx graph:", model_export.graph_module.print_readable())
fx graph: class GraphModule(torch.nn.Module):
def forward(self, p_module_0_module_weight: "f32[4, 3]", p_module_0_module_bias: "f32[4]", p_module_2_module_weight: "f32[4, 4]", p_module_2_module_bias: "f32[4]", x: "f32[1, 3]"):
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[1, 4]" = torch.ops.aten.linear.default(x, p_module_0_module_weight, p_module_0_module_bias); x = p_module_0_module_weight = p_module_0_module_bias = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/activation.py:143 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[1, 4]" = torch.ops.aten.relu.default(linear); linear = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[1, 4]" = torch.ops.aten.linear.default(relu, p_module_2_module_weight, p_module_2_module_bias); p_module_2_module_weight = p_module_2_module_bias = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:85 in forward, code: loc, scale = tensor.chunk(2, -1)
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1)
getitem: "f32[1, 2]" = chunk[0]
getitem_1: "f32[1, 2]" = chunk[1]; chunk = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/utils.py:70 in forward, code: return torch.nn.functional.softplus(x + self.bias) + self.min_val
add: "f32[1, 2]" = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus: "f32[1, 2]" = torch.ops.aten.softplus.default(add); add = None
add_1: "f32[1, 2]" = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
# File: /pytorch/tensordict/env/lib/python3.10/site-packages/tensordict/nn/distributions/continuous.py:86 in forward, code: scale = self.scale_mapping(scale).clamp_min(self.scale_lb)
clamp_min: "f32[1, 2]" = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = None
return (relu, linear_1, getitem, clamp_min)
Working with nested keys
Nested keys are a core feature of the tensordict library, and being able to export modules that read and write
nested entries is therefore an important feature to support.
Because keyword arguments must be regular strings, dispatch cannot work with nested keys directly. Instead, it
expects nested keys to be passed as a single string whose levels are joined by an underscore ("_"), as the
following example shows.
model_nested = Seq(
Mod(lambda x: x + 1, in_keys=[("some", "key")], out_keys=["hidden"]),
Mod(lambda x: x - 1, in_keys=["hidden"], out_keys=[("some", "output")]),
).select_out_keys(("some", "output"))
model_nested_export = export(model_nested, args=(), kwargs={"some_key": x})
print("exported module with nested input:", model_nested_export.module())
exported module with nested input: GraphModule(
(_guards_fn): GuardsFn()
)
def forward(self, some_key):
some_key, = fx_pytree.tree_flatten_spec(([], {'some_key':some_key}), self._in_spec)
_guards_fn = self._guards_fn(some_key); _guards_fn = None
add = torch.ops.aten.add.Tensor(some_key, 1); some_key = None
sub = torch.ops.aten.sub.Tensor(add, 1); add = None
return pytree.tree_unflatten((sub,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
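As a quick sanity check, the exported callable can be invoked with the flattened keyword argument; a hedged usage sketch:
# "some_key" is the underscore-joined form of the nested key ("some", "key")
print(model_nested_export.module()(some_key=x))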
Note that the callable returned by module() is a pure Python callable that can in turn be compiled using
compile().
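As an illustration (not part of the original script), torch.compile can be applied to the exported callable:
# Hedged sketch: compile the exported callable and run it once
compiled_module = torch.compile(model_export.module())
print(compiled_module(x=x))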
Saving the exported module
torch.export has its own serialization protocol, save() and load().
Conventionally, the ".pt2" extension is used:
>>> torch.export.save(model_export, "model.pt2")
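Loading the program back is symmetric; a quick sketch, assuming the file written above:
>>> loaded_program = torch.export.load("model.pt2")
>>> loaded_program.module()(x=x)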
Selecting the outputs
Recall that tensordict.nn keeps every intermediate value in the output, unless the user specifically asks
for only a specific value. During training, this can be very useful: one can easily log intermediate values of the
graph, or use them for other purposes (e.g., reconstruct a distribution based on its saved parameters, rather than
saving the Distribution object itself). One could also argue that, during training, the
impact on memory of registering intermediate values is negligible since they are part of the computational graph
used by torch.autograd to compute the parameter gradients.
During inference, though, we are most likely only interested in specific outputs of the model.
Because we want to extract the model for use cases that are independent of the tensordict library, it makes sense to
isolate only the outputs we need.
To do this, we have several options:
- Build the TensorDictSequential() with the selected_out_keys keyword argument, which will induce the selection of the desired entries during calls to the module;
- Use the select_out_keys() method, which will modify the out_keys attribute in place (this can be reverted through reset_out_keys());
- Wrap the existing instance in a TensorDictSequential() that will filter out the unwanted keys:
  >>> module_filtered = Seq(module, selected_out_keys=["loc"])
Let us test the model after selecting its output keys. When an x input is provided, we expect our model to output a single tensor corresponding to the “loc” output:
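A minimal sketch of that step, using the select_out_keys() method described above (it modifies out_keys in place):
model.select_out_keys("loc")
print(model(x=x))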
tensor([[ 0.1071, -0.0793]], grad_fn=<SplitBackward0>)
We see that the output is now a single tensor. We can create a new exported graph from this. Its computational graph should be simplified:
model_export = export(model, args=(), kwargs={"x": x})
print("module:", model_export.module())
module: GraphModule(
(module): Module(
(0): Module(
(module): Module()
)
(2): Module(
(module): Module()
)
)
(_guards_fn): GuardsFn()
)
def forward(self, x):
x, = fx_pytree.tree_flatten_spec(([], {'x':x}), self._in_spec)
module_0_module_weight = getattr(self.module, "0").module.weight
module_0_module_bias = getattr(self.module, "0").module.bias
module_2_module_weight = getattr(self.module, "2").module.weight
module_2_module_bias = getattr(self.module, "2").module.bias
_guards_fn = self._guards_fn(x); _guards_fn = None
linear = torch.ops.aten.linear.default(x, module_0_module_weight, module_0_module_bias); x = module_0_module_weight = module_0_module_bias = None
relu = torch.ops.aten.relu.default(linear); linear = None
linear_1 = torch.ops.aten.linear.default(relu, module_2_module_weight, module_2_module_bias); relu = module_2_module_weight = module_2_module_bias = None
chunk = torch.ops.aten.chunk.default(linear_1, 2, -1); linear_1 = None
getitem = chunk[0]
getitem_1 = chunk[1]; chunk = None
add = torch.ops.aten.add.Tensor(getitem_1, 0.5254586935043335); getitem_1 = None
softplus = torch.ops.aten.softplus.default(add); add = None
add_1 = torch.ops.aten.add.Tensor(softplus, 0.01); softplus = None
clamp_min = torch.ops.aten.clamp_min.default(add_1, 0.0001); add_1 = clamp_min = None
return pytree.tree_unflatten((getitem,), self._out_spec)
# To see more debug info, please use `graph_module.print_readable()`
This is all you need to know to use torch.export. Please refer to the
official documentation for more info.
Next steps and further reading
- Check the torch.export tutorial, available here;
- ONNX support: check the ONNX tutorials to learn more about this feature. Exporting to ONNX is very similar to the torch.export workflow explained here;
- For deployment of PyTorch code on servers without a Python environment, check the AOTInductor documentation.
Total running time of the script: (0 minutes 0.405 seconds)