torch.export#
Created On: Jun 12, 2025 | Last Updated On: Jul 17, 2025
Overview#
torch.export.export()
takes a torch.nn.Module
and produces a traced graph
representing only the Tensor computation of the function in an Ahead-of-Time
(AOT) fashion, which can subsequently be executed with different inputs or
serialized.
import torch
from torch.export import export, ExportedProgram
class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b
example_args = (torch.randn(10, 10), torch.randn(10, 10))
exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
# File: /tmp/ipykernel_552/2550508656.py:6 in forward, code: a = torch.sin(x)
sin: "f32[10, 10]" = torch.ops.aten.sin.default(x); x = None
# File: /tmp/ipykernel_552/2550508656.py:7 in forward, code: b = torch.cos(y)
cos: "f32[10, 10]" = torch.ops.aten.cos.default(y); y = None
# File: /tmp/ipykernel_552/2550508656.py:8 in forward, code: return a + b
add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos); sin = cos = None
return (add,)
Graph signature:
# inputs
x: USER_INPUT
y: USER_INPUT
# outputs
add: USER_OUTPUT
Range constraints: {}
torch.export
produces a clean intermediate representation (IR) with the
following invariants. More specifications about the IR can be found
here.
Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions of the original program.
Normalized: There are no Python semantics within the graph. Submodules from the original programs are inlined to form one fully flattened computational graph.
Graph properties: Once lowered with ExportedProgram.run_decompositions(), the graph is purely functional, meaning it does not contain operations with side effects such as mutations or aliasing. It does not mutate any intermediate values, parameters, or buffers.
Metadata: The graph contains metadata captured during tracing, such as a stacktrace from user’s code.
Under the hood, torch.export
builds on the following technologies:
TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a much improved graph capturing experience, with far fewer rewrites needed in order to fully trace the PyTorch code.
AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.
Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.
Existing frameworks#
torch.compile() also uses the same PT2 stack as torch.export, but differs in a few ways:
JIT vs. AOT: torch.compile() is a JIT compiler, which is not intended to be used to produce compiled artifacts outside of deployment, whereas torch.export captures the program ahead of time (AOT).
Partial vs. Full Graph Capture: When torch.compile() runs into an untraceable part of a model, it will “graph break” and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.
Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is a lot more flexible. torch.export will instead require users to provide more information or rewrite their code to make it traceable.
Compared to torch.fx.symbolic_trace()
, torch.export
traces using
TorchDynamo which operates at the Python bytecode level, giving it the ability
to trace arbitrary Python constructs not limited by what Python operator
overloading supports. Additionally, torch.export
keeps fine-grained track of
tensor metadata, so that conditionals on things like tensor shapes do not
fail tracing. In general, torch.export
is expected to work on more user
programs, and produce lower-level graphs (at the torch.ops.aten
operator
level). Note that users can still use torch.fx.symbolic_trace()
as a
preprocessing step before torch.export
.
Compared to torch.jit.script()
, torch.export
does not capture Python
control flow or data structures, unless using explicit control flow operators,
but it supports more Python language features due to its comprehensive coverage
over Python bytecodes. The resulting graphs are simpler and only have straight
line control flow, except for explicit control flow operators.
Compared to torch.jit.trace()
, torch.export
is sound:
it can trace code that performs integer computation on sizes and records
all of the side-conditions necessary to ensure that a particular
trace is valid for other inputs.
Exporting a PyTorch Model#
The main entrypoint is torch.export.export(), which takes a
torch.nn.Module and sample inputs, and captures the computation graph into a
torch.export.ExportedProgram. An example:
import torch
from torch.export import export, ExportedProgram
# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}
exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)
# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); x = p_conv_weight = p_conv_bias = None
# File: /tmp/ipykernel_552/2848084713.py:16 in forward, code: a.add_(constant)
add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant); conv2d = constant = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_); add_ = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]); relu = None
return (max_pool2d,)
Graph signature:
# inputs
p_conv_weight: PARAMETER target='conv.weight'
p_conv_bias: PARAMETER target='conv.bias'
x: USER_INPUT
constant: USER_INPUT
# outputs
max_pool2d: USER_OUTPUT
Range constraints: {}
tensor([[[[2.0887, 2.1907, 2.1275, ..., 1.3684, 1.4928, 2.0229],
[1.7835, 1.6115, 2.0009, ..., 2.2490, 2.6569, 1.2965],
[2.1505, 2.0436, 1.7390, ..., 1.4361, 1.8363, 2.6100],
...,
[1.9489, 2.3051, 2.4952, ..., 2.5161, 1.8805, 1.5157],
[2.3767, 1.7951, 1.6146, ..., 2.6584, 1.6433, 1.8267],
[1.8593, 1.9590, 2.3284, ..., 1.5938, 1.9882, 1.4999]],
[[1.8653, 1.6060, 2.0030, ..., 1.5183, 2.1048, 2.3949],
[2.0873, 1.9922, 2.2452, ..., 1.4654, 2.4572, 2.4037],
[3.0251, 1.6842, 1.9151, ..., 1.8950, 2.7677, 2.6601],
...,
[1.5861, 2.0446, 2.0074, ..., 2.3878, 2.5900, 1.8452],
[1.6562, 1.7270, 2.1694, ..., 2.6109, 2.5323, 2.0812],
[1.9849, 1.7291, 1.7152, ..., 3.2563, 2.3749, 2.4448]],
[[1.4913, 1.5377, 1.4863, ..., 1.4031, 1.2342, 1.6980],
[1.4926, 2.0321, 1.5689, ..., 1.3990, 1.7990, 1.5339],
[1.7734, 1.6504, 1.6811, ..., 1.0781, 1.5441, 1.6832],
...,
[1.6687, 1.6608, 1.5724, ..., 1.7221, 1.3858, 1.6259],
[1.7251, 1.5453, 1.6316, ..., 1.4759, 1.5368, 1.6774],
[1.6482, 1.3583, 1.3255, ..., 1.5452, 1.7060, 1.6059]],
...,
[[1.5087, 2.0547, 2.3977, ..., 1.7076, 1.9175, 1.4277],
[1.3376, 1.9704, 1.8310, ..., 1.3870, 1.5846, 2.0109],
[1.4451, 1.7761, 1.1628, ..., 1.5879, 1.7065, 1.4132],
...,
[1.3410, 0.9954, 1.5686, ..., 1.8266, 2.1916, 1.7419],
[1.9889, 0.8389, 1.6769, ..., 1.8209, 1.8571, 1.6561],
[1.3595, 1.4986, 1.5883, ..., 2.0472, 2.2438, 1.8291]],
[[1.8293, 1.3967, 1.8216, ..., 1.7859, 1.3739, 1.4712],
[1.6905, 2.2342, 2.0084, ..., 1.5636, 1.6380, 1.7226],
[2.1179, 1.6951, 2.2278, ..., 1.9913, 1.6082, 1.8772],
...,
[1.7332, 1.6669, 1.6953, ..., 2.4655, 1.3682, 1.8558],
[2.0079, 2.0344, 1.8899, ..., 2.3987, 1.6841, 1.7380],
[1.7526, 1.6374, 1.9604, ..., 1.5397, 2.0050, 1.2693]],
[[2.1842, 1.3602, 1.5108, ..., 1.3031, 1.8189, 1.5950],
[1.7679, 1.7063, 2.0322, ..., 1.5734, 1.6830, 1.3105],
[1.5963, 1.8083, 1.8361, ..., 1.7019, 2.2678, 1.7782],
...,
[1.0404, 1.8092, 1.1909, ..., 1.6335, 1.9506, 1.6260],
[1.4433, 1.2893, 1.4986, ..., 2.1747, 1.9072, 1.3845],
[1.5870, 2.0962, 1.9732, ..., 1.3077, 2.1104, 2.4622]]]],
grad_fn=<MaxPool2DWithIndicesBackward0>)
Inspecting the ExportedProgram, we can note the following:
The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.
The graph contains only torch.ops.aten operators found here and custom operators.
The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no get_attr nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().
The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.
The resulting shape and dtype of the tensors produced by each node in the graph are noted. For example, the conv2d node will result in a tensor of dtype torch.float32 and shape (1, 16, 256, 256).
Expressing Dynamism#
By default, torch.export will trace the program assuming all input shapes are
static, and will specialize the exported program to those dimensions. One
consequence of this is that at runtime, the program won’t work on inputs with
different shapes, even if they’re valid in eager mode.
An example:
import torch
import traceback as tb
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))
ep = torch.export.export(M(), example_args)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_552/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {}
Traceback (most recent call last):
File "/tmp/ipykernel_552/1522925308.py", line 28, in <module>
ep.module()(*example_args2) # fails
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 850, in call_wrapped
return self._wrapped_call(self, *args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 426, in __call__
raise e
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 413, in __call__
return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
return inner()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1806, in inner
args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc]
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 1005, in _fn
return fn(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/export/_unlift.py", line 83, in _check_input_constraints_pre_hook
_check_input_constraints_for_graph(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_export/utils.py", line 438, in _check_input_constraints_for_graph
_check_symint(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_export/utils.py", line 402, in _check_symint
raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 32, but got 64. If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes (e.g. with Dim.DYNAMIC)
However, some dimensions, such as a batch dimension, can be dynamic and vary
from run to run. Such dimensions must be specified by using the
torch.export.Dim()
API to create them and by passing them into
torch.export.export()
through the dynamic_shapes
argument.
import torch
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))
# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)
example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2) # success
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias); x1 = p_branch1_0_weight = p_branch1_0_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear); linear = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias); x2 = p_branch2_0_weight = p_branch2_0_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1); linear_1 = None
# File: /tmp/ipykernel_552/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer); relu = c_buffer = None
return (add, relu_1)
Graph signature:
# inputs
p_branch1_0_weight: PARAMETER target='branch1.0.weight'
p_branch1_0_bias: PARAMETER target='branch1.0.bias'
p_branch2_0_weight: PARAMETER target='branch2.0.weight'
p_branch2_0_bias: PARAMETER target='branch2.0.bias'
c_buffer: CONSTANT_TENSOR target='buffer'
x1: USER_INPUT
x2: USER_INPUT
# outputs
add: USER_OUTPUT
relu_1: USER_OUTPUT
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.7449, 1.0000, ..., 1.0000, 1.0899, 1.0000],
[1.6280, 1.0000, 1.6332, ..., 1.1401, 1.2773, 1.6427],
[1.5786, 1.0000, 1.0000, ..., 1.7563, 1.0000, 1.0000],
...,
[1.0000, 1.0000, 1.4527, ..., 1.0000, 1.0000, 1.0000],
[1.7919, 1.5184, 1.0000, ..., 1.0000, 1.0000, 1.5831],
[1.5096, 1.8946, 1.6435, ..., 1.0000, 1.0000, 1.0000]],
grad_fn=<AddBackward0>),
tensor([[0.9474, 0.0000, 0.0000, ..., 0.7154, 0.0000, 0.1311],
[0.0000, 1.5020, 0.6982, ..., 0.0000, 0.0000, 0.4365],
[0.5233, 0.0000, 0.0407, ..., 0.8772, 0.0000, 0.0000],
...,
[0.0000, 0.0000, 0.0000, ..., 0.0955, 0.0000, 0.5072],
[1.3185, 0.0000, 1.1436, ..., 0.0000, 0.6630, 0.2309],
[0.1237, 0.2012, 0.0000, ..., 0.2207, 0.0000, 0.2710]],
grad_fn=<ReluBackward0>))
Some additional things to note:
Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs x1 and x2, they have symbolic shapes of (s24, 64) and (s24, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. s24 is a symbol representing that this dimension can take a range of values.
ep.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that s24 has the range [0, int_oo]. For technical reasons that are difficult to explain here, symbolic dimensions are assumed to be not 0 or 1. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.
In the example, we used Dim("batch")
to create a dynamic dimension. This is
the most explicit way to specify dynamism. We can also use Dim.DYNAMIC
and
Dim.AUTO
to specify dynamism. We will go over both methods in the next section.
Named Dims#
For every dimension specified with Dim("name"), we will allocate a symbolic
shape. Specifying a Dim with the same name will result in the same symbol
being generated. This allows users to specify what symbols are allocated for
each input dimension.
from torch.export import Dim
batch = Dim("batch")
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}
For each Dim, we can specify minimum and maximum values. We also allow
specifying relations between Dims in univariate linear expressions: A * dim + B.
This allows users to specify more complex constraints like integer divisibility
for dynamic dimensions. These features allow users to place explicit
restrictions on the dynamic behavior of the ExportedProgram produced.
dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}
However, ConstraintViolationErrors will be raised if, while tracing, we emit guards
that conflict with the relations or static/dynamic specifications given. For
example, in the above specification, the following is asserted:
x.shape[0] is to have range [4, 256], and is related to y.shape[0] by y.shape[0] == 2 * x.shape[0].
x.shape[1] is static.
y.shape[1] has range [0, 512], and is unrelated to any other dimension.
If any of these assertions are found to be incorrect while tracing (e.g.,
x.shape[0] is static, or y.shape[1] has a smaller range, or
y.shape[0] != 2 * x.shape[0]), then a ConstraintViolationError will be
raised, and the user will need to change their dynamic_shapes specification.
Dim Hints#
Instead of explicitly specifying dynamism using Dim("name")
, we can let
torch.export
infer the ranges and relationships of the dynamic values using
Dim.DYNAMIC
. This is also a more convenient way to specify dynamism when you
don’t know specifically how dynamic your dynamic values are.
dynamic_shapes = {
    "x": (Dim.DYNAMIC, None),
    "y": (Dim.DYNAMIC, Dim.DYNAMIC),
}
We can also specify min/max values for Dim.DYNAMIC, which will serve as hints
to export. If export finds a different range while tracing, it will
automatically update the range without raising an error. We also cannot specify
relationships between dynamic values. Instead, these will be inferred by export,
and exposed to users through an inspection of assertions within the graph. With
this method of specifying dynamism, ConstraintViolationErrors will only be
raised if the specified value is inferred to be static.
An even more convenient way to specify dynamism is to use Dim.AUTO
, which will
behave like Dim.DYNAMIC
, but will not raise an error if the dimension is
inferred to be static. This is useful for when you have no idea what the dynamic
values are, and want to export the program with a “best effort” dynamic approach.
ShapesCollection#
When specifying which inputs are dynamic via dynamic_shapes
, we must specify
the dynamism of every input. For example, given the following inputs:
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
we would need to specify the dynamism of tensor_x
, tensor_y
, and tensor_z
along with the dynamic shapes:
# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}
torch.export.export(..., args, dynamic_shapes=dynamic_shapes)
However, this is particularly complicated as we need to specify the
dynamic_shapes
specification in the same nested input structure as the input
arguments. Instead, an easier way to specify dynamic shapes is with the helper
utility torch.export.ShapesCollection
, where instead of specifying the
dynamism of every single input, we can just assign directly which input
dimensions are dynamic.
import torch
class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z
tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}
dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}
print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
# File: /tmp/ipykernel_552/1070110726.py:5 in forward, code: x = inp["x"] * 1
mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1); inp_x = None
# File: /tmp/ipykernel_552/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2); inp_others_0 = None
# File: /tmp/ipykernel_552/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3); inp_others_1 = None
return (mul, mul_1, mul_2)
Graph signature:
# inputs
inp_x: USER_INPUT
inp_others_0: USER_INPUT
inp_others_1: USER_INPUT
# outputs
mul: USER_OUTPUT
mul_1: USER_OUTPUT
mul_2: USER_OUTPUT
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}
AdditionalInputs#
In the case where you don’t know how dynamic your inputs are, but you have an
ample set of testing or profiling data that can provide a fair sense of
representative inputs for a model, you can use
torch.export.AdditionalInputs
in place of dynamic_shapes
. You can
specify all the possible inputs used to trace the program, and
AdditionalInputs
will infer which inputs are dynamic based on which input
shapes are changing.
Example:
import dataclasses
import torch
import torch.utils._pytree as pytree
@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor

pytree.register_dataclass(D)

class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t
input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)
print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, d_b, d_i: "Sym(s70)", d_f, d_t: "f32[s54]"):
# File: /tmp/ipykernel_552/829931439.py:16 in forward, code: return d.i + d.f + d.t
sym_float: "Sym(ToFloat(s70))" = torch.sym_float(d_i); d_i = None
add: "Sym(ToFloat(s70) + 3.0)" = sym_float + 3.0; sym_float = None
add_1: "f32[s54]" = torch.ops.aten.add.Tensor(d_t, add); d_t = add = None
return (add_1,)
Graph signature:
# inputs
d_b: USER_INPUT
d_i: USER_INPUT
d_f: USER_INPUT
d_t: USER_INPUT
# outputs
add_1: USER_OUTPUT
Range constraints: {s70: VR[0, int_oo], s54: VR[2, int_oo]}
Serialization#
To save and load an ExportedProgram, users can use the torch.export.save() and
torch.export.load() APIs. The resulting file is a zipfile with a specific
structure. The details of the structure are defined in the
PT2 Archive Spec.
An example:
import torch
class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10
exported_program = torch.export.export(MyModule(), (torch.randn(5),))
torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')
Export IR, Decompositions#
The graph produced by torch.export contains only
ATen operators, which are the basic unit of
computation in PyTorch. As there are over
3000 ATen operators, export provides a way to narrow down the operator set used
in the graph based on certain characteristics, creating different IRs.
By default, export produces the most generic IR which contains all ATen
operators, including both functional and non-functional operators. A functional
operator is one that does not contain any mutations or aliasing of the inputs.
You can find a list of all ATen operators
here
and you can inspect if an operator is functional by checking
op._schema.is_mutable
.
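For instance, the schema check mentioned above can be applied to the in-place and functional variants of add. This is a small sketch; the two variable names are illustrative.

```python
import torch

inplace_add = torch.ops.aten.add_.Tensor
functional_add = torch.ops.aten.add.Tensor

# The schema string marks mutated arguments with an alias annotation such as (a!).
print(inplace_add._schema)
print(functional_add._schema)

# is_mutable reports whether the operator mutates any of its inputs.
print(inplace_add._schema.is_mutable, functional_add._schema.is_mutable)
```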
This generic IR can be used to train in eager PyTorch Autograd.
import torch
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)
ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = add_ = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
return (batch_norm,)
However, if you want to use the IR for inference, or decrease the number of
operators being used, you can lower the graph through the
ExportedProgram.run_decompositions() API. This method decomposes the
ATen operators into the ones specified in the decomposition table, and
functionalizes the graph.
By specifying an empty decomposition table (decomp_table={}), we only perform functionalization and do not apply any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.
import torch
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
As we can see, the previously in-place operator, torch.ops.aten.add_.Tensor,
has now been replaced with torch.ops.aten.add.Tensor, a functional operator.
We can also further lower this exported program to an operator set which only
contains the Core ATen Operator Set
(https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir),
which is a collection of only ~180 operators. This IR is optimal for backends
that do not want to reimplement all ATen operators.
import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
We now see that torch.ops.aten.conv2d.default has been decomposed into
torch.ops.aten.convolution.default. This is because convolution is a more
“core” operator, as operations like conv1d and conv2d can be implemented
using the same op.
We can also specify our own decomposition behaviors:
import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

# Start from the default table and override the rule for conv2d.
my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); x = p_conv_weight = p_conv_bias = None
mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2); convolution = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1) # type: ignore[has-type]
add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1); b_bn_num_batches_tracked = None
# File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
_native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05); mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4]; _native_batch_norm_legit_functional = None
return (getitem_3, getitem_4, add, getitem)
Notice that instead of torch.ops.aten.conv2d.default being decomposed into
torch.ops.aten.convolution.default alone, it is now decomposed into
torch.ops.aten.convolution.default and torch.ops.aten.mul.Tensor,
which matches our custom decomposition rule.
Limitations of torch.export#
As torch.export is a one-shot process for capturing a computation graph from
a PyTorch program, it might ultimately run into untraceable parts of programs, as
it is nearly impossible to support tracing all PyTorch and Python features. In
the case of torch.compile, an unsupported operation will cause a “graph
break” and the unsupported operation will be run with default Python evaluation.
In contrast, torch.export will require users to provide additional
information or rewrite parts of their code to make it traceable.
Draft-export is a great resource for listing the graph breaks that will be encountered when tracing the program, along with additional debug information to help resolve those errors.
ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.
TorchDynamo unsupported#
When using torch.export with strict=True, TorchDynamo is used to
evaluate the program at the Python bytecode level and trace it into a
graph. Compared to previous tracing frameworks, there will be significantly
fewer rewrites required to make a program traceable, but there will still be
some Python features that are unsupported. One way to get past
these graph breaks is to use non-strict export by setting the strict flag
to strict=False.
Data/Shape-Dependent Control Flow#
Graph breaks can also be encountered on data- or shape-dependent control flow
(for example, if x.shape[0] > 2 when shapes are not being specialized), which a
tracing compiler cannot handle without generating code for a combinatorially
exploding number of paths. In such cases, users will need to rewrite their code
using special control flow operators. Currently, we support torch.cond
to express if-else like control flow (more coming soon!).
You can also refer to this tutorial for more ways of addressing data-dependent errors.
Missing Fake/Meta Kernels for Operators#
When tracing, a FakeTensor kernel (aka meta kernel) is required for all operators. This is used to reason about the input/output shapes for this operator.
Please see this tutorial for more details.
In the unfortunate case where your model uses an ATen operator that does not have a FakeTensor kernel implementation yet, please file an issue.