torch.export

Created On: Jun 12, 2025 | Last Updated On: Jul 17, 2025

Overview

torch.export.export() takes a torch.nn.Module and produces a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, which can subsequently be executed with different inputs or serialized.

import torch
from torch.export import export, ExportedProgram

class Mod(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        a = torch.sin(x)
        b = torch.cos(y)
        return a + b

example_args = (torch.randn(10, 10), torch.randn(10, 10))

exported_program: ExportedProgram = export(Mod(), args=example_args)
print(exported_program)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, x: "f32[10, 10]", y: "f32[10, 10]"):
             # File: /tmp/ipykernel_552/2550508656.py:6 in forward, code: a = torch.sin(x)
            sin: "f32[10, 10]" = torch.ops.aten.sin.default(x);  x = None
            
             # File: /tmp/ipykernel_552/2550508656.py:7 in forward, code: b = torch.cos(y)
            cos: "f32[10, 10]" = torch.ops.aten.cos.default(y);  y = None
            
             # File: /tmp/ipykernel_552/2550508656.py:8 in forward, code: return a + b
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(sin, cos);  sin = cos = None
            return (add,)
            
Graph signature: 
    # inputs
    x: USER_INPUT
    y: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    
Range constraints: {}

torch.export produces a clean intermediate representation (IR) with the following invariants. More specifications about the IR can be found here.

  • Soundness: It is guaranteed to be a sound representation of the original program, and maintains the same calling conventions as the original program.

  • Normalized: There are no Python semantics within the graph. Submodules from the original program are inlined to form one fully flattened computational graph.

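  • Graph properties: By default, the graph may contain both functional and non-functional operators, including mutations such as the in-place aten.add_. To obtain a purely functional graph, with no mutations or aliasing of intermediate values, parameters, or buffers, lower it with ExportedProgram.run_decompositions().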

  • Metadata: The graph contains metadata captured during tracing, such as a stacktrace from user’s code.

Under the hood, torch.export leverages the following technologies:

  • TorchDynamo (torch._dynamo) is an internal API that uses a CPython feature called the Frame Evaluation API to safely trace PyTorch graphs. This provides a massively improved graph capturing experience, with far fewer rewrites needed in order to fully trace the PyTorch code.

  • AOT Autograd provides a functionalized PyTorch graph and ensures the graph is decomposed/lowered to the ATen operator set.

  • Torch FX (torch.fx) is the underlying representation of the graph, allowing flexible Python-based transformations.

Existing frameworks

torch.compile() also utilizes the same PT2 stack as torch.export, but is slightly different:

  • JIT vs. AOT: torch.compile() is a JIT compiler and is not intended to be used to produce standalone compiled artifacts. By contrast, torch.export captures the program ahead of time (AOT) so that the result can be saved and deployed.

  • Partial vs. Full Graph Capture: When torch.compile() runs into an untraceable part of a model, it will “graph break” and fall back to running the program in the eager Python runtime. In comparison, torch.export aims to get a full graph representation of a PyTorch model, so it will error out when something untraceable is reached. Since torch.export produces a full graph disjoint from any Python features or runtime, this graph can then be saved, loaded, and run in different environments and languages.

  • Usability tradeoff: Since torch.compile() is able to fall back to the Python runtime whenever it reaches something untraceable, it is a lot more flexible. torch.export will instead require users to provide more information or rewrite their code to make it traceable.
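To illustrate the "error out" behavior, here is a minimal sketch with a hypothetical module whose Python branch depends on tensor values. Eager mode simply takes the branch, but export has no eager fallback and raises instead. The exact exception type differs between PyTorch versions, so the sketch catches Exception. Such code typically needs an explicit control-flow operator (for example torch.cond) to become exportable.

```python
import torch

class DataDependent(torch.nn.Module):
    def forward(self, x):
        # The branch depends on tensor *values*, not on shapes or dtypes
        if x.sum() > 0:
            return x.sin()
        return x.cos()

m = DataDependent()
x = torch.ones(3)

# Eager execution just runs the Python branch
eager_out = m(x)

# torch.export cannot capture a single graph here, so it raises
try:
    torch.export.export(m, (x,))
    export_error = None
except Exception as e:  # exact exception type varies across versions
    export_error = e
```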

Compared to torch.fx.symbolic_trace(), torch.export traces using TorchDynamo which operates at the Python bytecode level, giving it the ability to trace arbitrary Python constructs not limited by what Python operator overloading supports. Additionally, torch.export keeps fine-grained track of tensor metadata, so that conditionals on things like tensor shapes do not fail tracing. In general, torch.export is expected to work on more user programs, and produce lower-level graphs (at the torch.ops.aten operator level). Note that users can still use torch.fx.symbolic_trace() as a preprocessing step before torch.export.

Compared to torch.jit.script(), torch.export does not capture Python control flow or data structures, unless using explicit control flow operators, but it supports more Python language features due to its comprehensive coverage of Python bytecode. The resulting graphs are simpler and only have straight-line control flow, except for explicit control flow operators.

Compared to torch.jit.trace(), torch.export is sound: it can trace code that performs integer computation on sizes and records all of the side-conditions necessary to ensure that a particular trace is valid for other inputs.

Exporting a PyTorch Model

The main entrypoint is through torch.export.export(), which takes a torch.nn.Module and sample inputs, and captures the computation graph into a torch.export.ExportedProgram. An example:

import torch
from torch.export import export, ExportedProgram

# Simple module for demonstration
class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, padding=1
        )
        self.relu = torch.nn.ReLU()
        self.maxpool = torch.nn.MaxPool2d(kernel_size=3)

    def forward(self, x: torch.Tensor, *, constant=None) -> torch.Tensor:
        a = self.conv(x)
        a.add_(constant)
        return self.maxpool(self.relu(a))

example_args = (torch.randn(1, 3, 256, 256),)
example_kwargs = {"constant": torch.ones(1, 16, 256, 256)}

exported_program: ExportedProgram = export(
    M(), args=example_args, kwargs=example_kwargs
)
print(exported_program)

# To run the exported program, we can use the `module()` method
print(exported_program.module()(torch.randn(1, 3, 256, 256), constant=torch.ones(1, 16, 256, 256)))
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[16, 3, 3, 3]", p_conv_bias: "f32[16]", x: "f32[1, 3, 256, 256]", constant: "f32[1, 16, 256, 256]"):
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
            conv2d: "f32[1, 16, 256, 256]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  x = p_conv_weight = p_conv_bias = None
            
             # File: /tmp/ipykernel_552/2848084713.py:16 in forward, code: a.add_(constant)
            add_: "f32[1, 16, 256, 256]" = torch.ops.aten.add_.Tensor(conv2d, constant);  conv2d = constant = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[1, 16, 256, 256]" = torch.ops.aten.relu.default(add_);  add_ = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/pooling.py:226 in forward, code: return F.max_pool2d(
            max_pool2d: "f32[1, 16, 85, 85]" = torch.ops.aten.max_pool2d.default(relu, [3, 3], [3, 3]);  relu = None
            return (max_pool2d,)
            
Graph signature: 
    # inputs
    p_conv_weight: PARAMETER target='conv.weight'
    p_conv_bias: PARAMETER target='conv.bias'
    x: USER_INPUT
    constant: USER_INPUT
    
    # outputs
    max_pool2d: USER_OUTPUT
    
Range constraints: {}

tensor([[[[2.0887, 2.1907, 2.1275,  ..., 1.3684, 1.4928, 2.0229],
          [1.7835, 1.6115, 2.0009,  ..., 2.2490, 2.6569, 1.2965],
          [2.1505, 2.0436, 1.7390,  ..., 1.4361, 1.8363, 2.6100],
          ...,
          [1.9489, 2.3051, 2.4952,  ..., 2.5161, 1.8805, 1.5157],
          [2.3767, 1.7951, 1.6146,  ..., 2.6584, 1.6433, 1.8267],
          [1.8593, 1.9590, 2.3284,  ..., 1.5938, 1.9882, 1.4999]],

         [[1.8653, 1.6060, 2.0030,  ..., 1.5183, 2.1048, 2.3949],
          [2.0873, 1.9922, 2.2452,  ..., 1.4654, 2.4572, 2.4037],
          [3.0251, 1.6842, 1.9151,  ..., 1.8950, 2.7677, 2.6601],
          ...,
          [1.5861, 2.0446, 2.0074,  ..., 2.3878, 2.5900, 1.8452],
          [1.6562, 1.7270, 2.1694,  ..., 2.6109, 2.5323, 2.0812],
          [1.9849, 1.7291, 1.7152,  ..., 3.2563, 2.3749, 2.4448]],

         [[1.4913, 1.5377, 1.4863,  ..., 1.4031, 1.2342, 1.6980],
          [1.4926, 2.0321, 1.5689,  ..., 1.3990, 1.7990, 1.5339],
          [1.7734, 1.6504, 1.6811,  ..., 1.0781, 1.5441, 1.6832],
          ...,
          [1.6687, 1.6608, 1.5724,  ..., 1.7221, 1.3858, 1.6259],
          [1.7251, 1.5453, 1.6316,  ..., 1.4759, 1.5368, 1.6774],
          [1.6482, 1.3583, 1.3255,  ..., 1.5452, 1.7060, 1.6059]],

         ...,

         [[1.5087, 2.0547, 2.3977,  ..., 1.7076, 1.9175, 1.4277],
          [1.3376, 1.9704, 1.8310,  ..., 1.3870, 1.5846, 2.0109],
          [1.4451, 1.7761, 1.1628,  ..., 1.5879, 1.7065, 1.4132],
          ...,
          [1.3410, 0.9954, 1.5686,  ..., 1.8266, 2.1916, 1.7419],
          [1.9889, 0.8389, 1.6769,  ..., 1.8209, 1.8571, 1.6561],
          [1.3595, 1.4986, 1.5883,  ..., 2.0472, 2.2438, 1.8291]],

         [[1.8293, 1.3967, 1.8216,  ..., 1.7859, 1.3739, 1.4712],
          [1.6905, 2.2342, 2.0084,  ..., 1.5636, 1.6380, 1.7226],
          [2.1179, 1.6951, 2.2278,  ..., 1.9913, 1.6082, 1.8772],
          ...,
          [1.7332, 1.6669, 1.6953,  ..., 2.4655, 1.3682, 1.8558],
          [2.0079, 2.0344, 1.8899,  ..., 2.3987, 1.6841, 1.7380],
          [1.7526, 1.6374, 1.9604,  ..., 1.5397, 2.0050, 1.2693]],

         [[2.1842, 1.3602, 1.5108,  ..., 1.3031, 1.8189, 1.5950],
          [1.7679, 1.7063, 2.0322,  ..., 1.5734, 1.6830, 1.3105],
          [1.5963, 1.8083, 1.8361,  ..., 1.7019, 2.2678, 1.7782],
          ...,
          [1.0404, 1.8092, 1.1909,  ..., 1.6335, 1.9506, 1.6260],
          [1.4433, 1.2893, 1.4986,  ..., 2.1747, 1.9072, 1.3845],
          [1.5870, 2.0962, 1.9732,  ..., 1.3077, 2.1104, 2.4622]]]],
       grad_fn=<MaxPool2DWithIndicesBackward0>)

Inspecting the ExportedProgram, we can note the following:

  • The torch.fx.Graph contains the computation graph of the original program, along with records of the original code for easy debugging.

  • The graph contains only torch.ops.aten operators found here and custom operators.

  • The parameters (weight and bias to conv) are lifted as inputs to the graph, resulting in no get_attr nodes in the graph, which previously existed in the result of torch.fx.symbolic_trace().

  • The torch.export.ExportGraphSignature models the input and output signature, along with specifying which inputs are parameters.

  • The resulting shape and dtype of tensors produced by each node in the graph are noted. For example, the conv2d node will result in a tensor of dtype torch.float32 and shape (1, 16, 256, 256).

Expressing Dynamism

By default, torch.export traces the program assuming all input shapes are static, and specializes the exported program to those dimensions. One consequence of this is that at runtime, the program won’t work on inputs with different shapes, even if they’re valid in eager mode.

An example:

import torch
import traceback as tb

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

ep = torch.export.export(M(), example_args)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
try:
    ep.module()(*example_args2)  # fails
except Exception:
    tb.print_exc()
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[32, 64]", x2: "f32[32, 128]"):
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[32, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[32, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[32, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[32, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_552/1522925308.py:19 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[32, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {}
Traceback (most recent call last):
  File "/tmp/ipykernel_552/1522925308.py", line 28, in <module>
    ep.module()(*example_args2)  # fails
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 850, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 426, in __call__
    raise e
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/fx/graph_module.py", line 413, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1879, in _call_impl
    return inner()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1806, in inner
    args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 1005, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/export/_unlift.py", line 83, in _check_input_constraints_pre_hook
    _check_input_constraints_for_graph(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_export/utils.py", line 438, in _check_input_constraints_for_graph
    _check_symint(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/_export/utils.py", line 402, in _check_symint
    raise RuntimeError(
RuntimeError: Expected input at *args[0].shape[0] to be equal to 32, but got 64. If you meant for this dimension to be dynamic, please re-export and specify dynamic_shapes (e.g. with Dim.DYNAMIC)

However, some dimensions, such as a batch dimension, can be dynamic and vary from run to run. Such dimensions must be specified by using the torch.export.Dim() API to create them and by passing them into torch.export.export() through the dynamic_shapes argument.

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.branch1 = torch.nn.Sequential(
            torch.nn.Linear(64, 32), torch.nn.ReLU()
        )
        self.branch2 = torch.nn.Sequential(
            torch.nn.Linear(128, 64), torch.nn.ReLU()
        )
        self.buffer = torch.ones(32)

    def forward(self, x1, x2):
        out1 = self.branch1(x1)
        out2 = self.branch2(x2)
        return (out1 + self.buffer, out2)

example_args = (torch.randn(32, 64), torch.randn(32, 128))

# Create a dynamic batch size
batch = torch.export.Dim("batch")
# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

ep = torch.export.export(
    M(), args=example_args, dynamic_shapes=dynamic_shapes
)
print(ep)

example_args2 = (torch.randn(64, 64), torch.randn(64, 128))
ep.module()(*example_args2)  # success
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_branch1_0_weight: "f32[32, 64]", p_branch1_0_bias: "f32[32]", p_branch2_0_weight: "f32[64, 128]", p_branch2_0_bias: "f32[64]", c_buffer: "f32[32]", x1: "f32[s24, 64]", x2: "f32[s24, 128]"):
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear: "f32[s24, 32]" = torch.ops.aten.linear.default(x1, p_branch1_0_weight, p_branch1_0_bias);  x1 = p_branch1_0_weight = p_branch1_0_bias = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu: "f32[s24, 32]" = torch.ops.aten.relu.default(linear);  linear = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/linear.py:134 in forward, code: return F.linear(input, self.weight, self.bias)
            linear_1: "f32[s24, 64]" = torch.ops.aten.linear.default(x2, p_branch2_0_weight, p_branch2_0_bias);  x2 = p_branch2_0_weight = p_branch2_0_bias = None
            
             # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/activation.py:144 in forward, code: return F.relu(input, inplace=self.inplace)
            relu_1: "f32[s24, 64]" = torch.ops.aten.relu.default(linear_1);  linear_1 = None
            
             # File: /tmp/ipykernel_552/3456136871.py:18 in forward, code: return (out1 + self.buffer, out2)
            add: "f32[s24, 32]" = torch.ops.aten.add.Tensor(relu, c_buffer);  relu = c_buffer = None
            return (add, relu_1)
            
Graph signature: 
    # inputs
    p_branch1_0_weight: PARAMETER target='branch1.0.weight'
    p_branch1_0_bias: PARAMETER target='branch1.0.bias'
    p_branch2_0_weight: PARAMETER target='branch2.0.weight'
    p_branch2_0_bias: PARAMETER target='branch2.0.bias'
    c_buffer: CONSTANT_TENSOR target='buffer'
    x1: USER_INPUT
    x2: USER_INPUT
    
    # outputs
    add: USER_OUTPUT
    relu_1: USER_OUTPUT
    
Range constraints: {s24: VR[0, int_oo]}
(tensor([[1.0000, 1.7449, 1.0000,  ..., 1.0000, 1.0899, 1.0000],
         [1.6280, 1.0000, 1.6332,  ..., 1.1401, 1.2773, 1.6427],
         [1.5786, 1.0000, 1.0000,  ..., 1.7563, 1.0000, 1.0000],
         ...,
         [1.0000, 1.0000, 1.4527,  ..., 1.0000, 1.0000, 1.0000],
         [1.7919, 1.5184, 1.0000,  ..., 1.0000, 1.0000, 1.5831],
         [1.5096, 1.8946, 1.6435,  ..., 1.0000, 1.0000, 1.0000]],
        grad_fn=<AddBackward0>),
 tensor([[0.9474, 0.0000, 0.0000,  ..., 0.7154, 0.0000, 0.1311],
         [0.0000, 1.5020, 0.6982,  ..., 0.0000, 0.0000, 0.4365],
         [0.5233, 0.0000, 0.0407,  ..., 0.8772, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0955, 0.0000, 0.5072],
         [1.3185, 0.0000, 1.1436,  ..., 0.0000, 0.6630, 0.2309],
         [0.1237, 0.2012, 0.0000,  ..., 0.2207, 0.0000, 0.2710]],
        grad_fn=<ReluBackward0>))

Some additional things to note:

  • Through the torch.export.Dim() API and the dynamic_shapes argument, we specified the first dimension of each input to be dynamic. Looking at the inputs x1 and x2, they have symbolic shapes of (s24, 64) and (s24, 128), instead of the (32, 64) and (32, 128) shaped tensors that we passed in as example inputs. s24 is a symbol representing that this dimension can take a range of values. The exact symbol name is generated during tracing and may differ across PyTorch versions.

  • exported_program.range_constraints describes the ranges of each symbol appearing in the graph. In this case, we see that s24 has the range [0, int_oo]. For technical reasons that are difficult to explain here, such dimensions are assumed not to be 0 or 1 during tracing. This is not a bug, and does not necessarily mean that the exported program will not work for dimensions 0 or 1. See The 0/1 Specialization Problem for an in-depth discussion of this topic.

In the example, we used Dim("batch") to create a dynamic dimension. This is the most explicit way to specify dynamism. We can also use Dim.DYNAMIC and Dim.AUTO to specify dynamism. We will go over each of these approaches in the following sections.

Named Dims

For every dimension specified with Dim("name"), we will allocate a symbolic shape. Specifying a Dim with the same name will result in the same symbol being generated. This allows users to specify what symbols are allocated for each input dimension.

from torch.export import Dim

batch = Dim("batch")
dynamic_shapes = {"x1": {0: batch}, "x2": {0: batch}}

For each Dim, we can specify minimum and maximum values. We also allow specifying relations between Dims in univariate linear expressions: A * dim + B. This allows users to specify more complex constraints like integer divisibility for dynamic dimensions. These features allow users to place explicit restrictions on the dynamic behavior of the ExportedProgram produced.

from torch.export import Dim

dx = Dim("dx", min=4, max=256)
dh = Dim("dh", max=512)
dynamic_shapes = {
    "x": (dx, None),
    "y": (2 * dx, dh),
}

However, a ConstraintViolationError will be raised if, while tracing, export emits guards that conflict with the relations or static/dynamic specifications given. For example, in the above specification, the following is asserted:

  • x.shape[0] is to have range [4, 256], and related to y.shape[0] by y.shape[0] == 2 * x.shape[0].

  • x.shape[1] is static.

  • y.shape[1] has range [0, 512], and is unrelated to any other dimension.

If any of these assertions are found to be incorrect while tracing (e.g., x.shape[0] is static, y.shape[1] has a smaller range, or y.shape[0] != 2 * x.shape[0]), then a ConstraintViolationError will be raised, and the user will need to change their dynamic_shapes specification.

Dim Hints

Instead of explicitly specifying dynamism using Dim("name"), we can let torch.export infer the ranges and relationships of the dynamic values using Dim.DYNAMIC. This is also a more convenient way to specify dynamism when you don’t know specifically how dynamic your dynamic values are.

from torch.export import Dim

dynamic_shapes = {
    "x": (Dim.DYNAMIC, None),
    "y": (Dim.DYNAMIC, Dim.DYNAMIC),
}

We can also specify min/max values for Dim.DYNAMIC, which serve as hints to export. However, if export finds a different range while tracing, it automatically updates the range without raising an error. We also cannot specify relationships between dynamic values. Instead, export infers them and exposes them to users through assertions within the graph. With this method of specifying dynamism, a ConstraintViolationError is raised only if a dimension marked as dynamic is inferred to be static.

An even more convenient way to specify dynamism is to use Dim.AUTO, which behaves like Dim.DYNAMIC but does not raise an error if the dimension is inferred to be static. This is useful when you don't know which values are dynamic and want to export the program with a “best effort” dynamic approach.

ShapesCollection

When specifying which inputs are dynamic via dynamic_shapes, we must specify the dynamism of every input. For example, given the following inputs:

args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

we would need to specify the dynamism of tensor_x, tensor_y, and tensor_z along with the dynamic shapes:

# With named-Dims
dim = torch.export.Dim(...)
dynamic_shapes = {"x": {0: dim, 1: dim + 1}, "others": [{0: dim * 2}, None]}

torch.export.export(..., args, dynamic_shapes=dynamic_shapes)

However, this is particularly complicated, as we need to write the dynamic_shapes specification in the same nested structure as the input arguments. An easier way to specify dynamic shapes is the helper utility torch.export.ShapesCollection. Instead of describing the dynamism of every single input, we assign dynamic shapes directly to the specific input tensors that have dynamic dimensions.

import torch

class M(torch.nn.Module):
    def forward(self, inp):
        x = inp["x"] * 1
        y = inp["others"][0] * 2
        z = inp["others"][1] * 3
        return x, y, z

tensor_x = torch.randn(3, 4, 8)
tensor_y = torch.randn(6)
tensor_z = torch.randn(6)
args = {"x": tensor_x, "others": [tensor_y, tensor_z]}

dim = torch.export.Dim("dim")
sc = torch.export.ShapesCollection()
sc[tensor_x] = (dim, dim + 1, 8)
sc[tensor_y] = {0: dim * 2}

print(sc.dynamic_shapes(M(), (args,)))
ep = torch.export.export(M(), (args,), dynamic_shapes=sc)
print(ep)
{'inp': {'x': (Dim('dim', min=0), dim + 1, 8), 'others': [{0: 2*dim}, None]}}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, inp_x: "f32[s96, s96 + 1, 8]", inp_others_0: "f32[2*s96]", inp_others_1: "f32[6]"):
             # File: /tmp/ipykernel_552/1070110726.py:5 in forward, code: x = inp["x"] * 1
            mul: "f32[s96, s96 + 1, 8]" = torch.ops.aten.mul.Tensor(inp_x, 1);  inp_x = None
            
             # File: /tmp/ipykernel_552/1070110726.py:6 in forward, code: y = inp["others"][0] * 2
            mul_1: "f32[2*s96]" = torch.ops.aten.mul.Tensor(inp_others_0, 2);  inp_others_0 = None
            
             # File: /tmp/ipykernel_552/1070110726.py:7 in forward, code: z = inp["others"][1] * 3
            mul_2: "f32[6]" = torch.ops.aten.mul.Tensor(inp_others_1, 3);  inp_others_1 = None
            return (mul, mul_1, mul_2)
            
Graph signature: 
    # inputs
    inp_x: USER_INPUT
    inp_others_0: USER_INPUT
    inp_others_1: USER_INPUT
    
    # outputs
    mul: USER_OUTPUT
    mul_1: USER_OUTPUT
    mul_2: USER_OUTPUT
    
Range constraints: {s96: VR[0, int_oo], s96 + 1: VR[1, int_oo], 2*s96: VR[0, int_oo]}

AdditionalInputs

In the case where you don’t know how dynamic your inputs are, but you have an ample set of testing or profiling data that can provide a fair sense of representative inputs for a model, you can use torch.export.AdditionalInputs in place of dynamic_shapes. You can specify all the possible inputs used to trace the program, and AdditionalInputs will infer which inputs are dynamic based on which input shapes are changing.

Example:

import dataclasses
import torch
import torch.utils._pytree as pytree

@dataclasses.dataclass
class D:
    b: bool
    i: int
    f: float
    t: torch.Tensor

pytree.register_dataclass(D)

class M(torch.nn.Module):
    def forward(self, d: D):
        return d.i + d.f + d.t

input1 = (D(True, 3, 3.0, torch.ones(3)),)
input2 = (D(True, 4, 3.0, torch.ones(4)),)
ai = torch.export.AdditionalInputs()
ai.add(input1)
ai.add(input2)

print(ai.dynamic_shapes(M(), input1))
ep = torch.export.export(M(), input1, dynamic_shapes=ai)
print(ep)
{'d': [None, _DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True), None, (_DimHint(type=<_DimHintType.DYNAMIC: 3>, min=None, max=None, _factory=True),)]}
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, d_b, d_i: "Sym(s70)", d_f, d_t: "f32[s54]"):
             # File: /tmp/ipykernel_552/829931439.py:16 in forward, code: return d.i + d.f + d.t
            sym_float: "Sym(ToFloat(s70))" = torch.sym_float(d_i);  d_i = None
            add: "Sym(ToFloat(s70) + 3.0)" = sym_float + 3.0;  sym_float = None
            add_1: "f32[s54]" = torch.ops.aten.add.Tensor(d_t, add);  d_t = add = None
            return (add_1,)
            
Graph signature: 
    # inputs
    d_b: USER_INPUT
    d_i: USER_INPUT
    d_f: USER_INPUT
    d_t: USER_INPUT
    
    # outputs
    add_1: USER_OUTPUT
    
Range constraints: {s70: VR[0, int_oo], s54: VR[2, int_oo]}

Serialization

To save the ExportedProgram, users can use the torch.export.save() and torch.export.load() APIs. The resulting file is a zipfile with a specific structure. The details of the structure are defined in the PT2 Archive Spec.

An example:

import torch

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

exported_program = torch.export.export(MyModule(), (torch.randn(5),))

torch.export.save(exported_program, 'exported_program.pt2')
saved_exported_program = torch.export.load('exported_program.pt2')

Export IR, Decompositions

torch.export produces a graph containing only ATen operators, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not contain any mutations or aliasing of the inputs. You can find a list of all ATen operators here and you can inspect if an operator is functional by checking op._schema.is_mutable.
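As a quick illustration of the is_mutable check, the sketch below compares an in-place operator with functional ones. It relies on the private _schema attribute named above, so treat it as an inspection aid rather than a stable API.

```python
import torch

# OpOverloads expose their schema; in-place operators such as add_ mutate an input
ops = (
    torch.ops.aten.add_.Tensor,   # in-place (non-functional)
    torch.ops.aten.add.Tensor,    # functional
    torch.ops.aten.sin.default,   # functional
)

mutability = {str(op): op._schema.is_mutable for op in ops}
```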

This generic IR can be used to train in eager PyTorch Autograd.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
print(ep_for_training.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add_: "i64[]" = torch.ops.aten.add_.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = add_ = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        batch_norm: "f32[1, 3, 3, 3]" = torch.ops.aten.batch_norm.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05, True);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        return (batch_norm,)
        

However, if you want to use the IR for inference, or to decrease the number of operators being used, you can lower the graph through the ExportedProgram.run_decompositions() API. This method decomposes the ATen operators into the ones specified in the decomposition table and functionalizes the graph.

By specifying an empty decomposition table (decomp_table={}), we perform only functionalization and no additional decompositions. This results in an IR that contains ~2000 operators (instead of the ~3000 operators in the generic IR), which is ideal for inference.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
print(ep_for_inference.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        conv2d: "f32[1, 3, 3, 3]" = torch.ops.aten.conv2d.default(x, p_conv_weight, p_conv_bias);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(conv2d, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  conv2d = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

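As we can see, the previously in-place operator torch.ops.aten.add_.Tensor has now been replaced with torch.ops.aten.add.Tensor, a functional operator. Likewise, torch.ops.aten.batch_norm.default has become torch.ops.aten._native_batch_norm_legit_functional.default, which returns the updated running statistics as extra outputs instead of mutating the buffers.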

We can also further lower this exported program to an operator set that contains only the Core ATen Operator Set (https://pytorch.org/docs/main/torch.compiler_ir.html#core-aten-ir), a collection of only ~180 operators. This IR is ideal for backends that do not want to reimplement all ATen operators.

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))
with torch.no_grad():
    core_aten_ir = ep_for_training.run_decompositions(decomp_table=None)
print(core_aten_ir.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(convolution, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  convolution = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

We now see that torch.ops.aten.conv2d.default has been decomposed into torch.ops.aten.convolution.default. This is because convolution is a more “core” operator, as operations like conv1d and conv2d can be implemented using the same op.

We can also specify our own decomposition behaviors:

import torch

class M(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 3, 1, 1)
        self.bn = torch.nn.BatchNorm2d(3)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return (x,)

ep_for_training = torch.export.export(M(), (torch.randn(1, 1, 3, 3),))

my_decomp_table = torch.export.default_decompositions()

def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
    return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
my_ep = ep_for_training.run_decompositions(my_decomp_table)
print(my_ep.graph_module.print_readable(print_output=False))
class GraphModule(torch.nn.Module):
    def forward(self, p_conv_weight: "f32[3, 1, 1, 1]", p_conv_bias: "f32[3]", p_bn_weight: "f32[3]", p_bn_bias: "f32[3]", b_bn_running_mean: "f32[3]", b_bn_running_var: "f32[3]", b_bn_num_batches_tracked: "i64[]", x: "f32[1, 1, 3, 3]"):
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/conv.py:548 in forward, code: return self._conv_forward(input, self.weight, self.bias)
        convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(x, p_conv_weight, p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1);  x = p_conv_weight = p_conv_bias = None
        mul: "f32[1, 3, 3, 3]" = torch.ops.aten.mul.Tensor(convolution, 2);  convolution = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:173 in forward, code: self.num_batches_tracked.add_(1)  # type: ignore[has-type]
        add: "i64[]" = torch.ops.aten.add.Tensor(b_bn_num_batches_tracked, 1);  b_bn_num_batches_tracked = None
        
         # File: /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py:193 in forward, code: return F.batch_norm(
        _native_batch_norm_legit_functional = torch.ops.aten._native_batch_norm_legit_functional.default(mul, p_bn_weight, p_bn_bias, b_bn_running_mean, b_bn_running_var, True, 0.1, 1e-05);  mul = p_bn_weight = p_bn_bias = b_bn_running_mean = b_bn_running_var = None
        getitem: "f32[1, 3, 3, 3]" = _native_batch_norm_legit_functional[0]
        getitem_3: "f32[3]" = _native_batch_norm_legit_functional[3]
        getitem_4: "f32[3]" = _native_batch_norm_legit_functional[4];  _native_batch_norm_legit_functional = None
        return (getitem_3, getitem_4, add, getitem)
        

Notice that torch.ops.aten.conv2d.default is no longer decomposed into torch.ops.aten.convolution.default alone; it now becomes torch.ops.aten.convolution.default followed by torch.ops.aten.mul.Tensor, which matches our custom decomposition rule.

Limitations of torch.export

As torch.export is a one-shot process for capturing a computation graph from a PyTorch program, it might ultimately run into untraceable parts of programs as it is nearly impossible to support tracing all PyTorch and Python features. In the case of torch.compile, an unsupported operation will cause a “graph break” and the unsupported operation will be run with default Python evaluation. In contrast, torch.export will require users to provide additional information or rewrite parts of their code to make it traceable.

Draft-export is a great resource for listing the graph breaks that will be encountered when tracing the program, along with additional debug information to resolve those errors.

ExportDB is also a great resource for learning about the kinds of programs that are supported and unsupported, along with ways to rewrite programs to make them traceable.

TorchDynamo unsupported

When torch.export is used with strict=True, it uses TorchDynamo to evaluate the program at the Python bytecode level and trace it into a graph. Compared to previous tracing frameworks, significantly fewer rewrites are required to make a program traceable, but some Python features are still unsupported. One way to get past these graph breaks is to use non-strict export by setting strict=False.

Data/Shape-Dependent Control Flow

Graph breaks can also be encountered on data-dependent control flow (e.g. if x.shape[0] > 2) when shapes are not being specialized, because a tracing compiler cannot handle such branches without generating code for a combinatorially exploding number of paths. In such cases, users need to rewrite their code using special control flow operators. Currently, we support torch.cond to express if-else-like control flow (more coming soon!).

You can also refer to this tutorial for more ways of addressing data-dependent errors.

Missing Fake/Meta Kernels for Operators

When tracing, a FakeTensor kernel (aka meta kernel) is required for every operator. It is used to reason about the operator's input/output shapes without computing on real data.

Please see this tutorial for more details.

In the unfortunate case where your model uses an ATen operator that does not have a FakeTensor kernel implementation yet, please file an issue.