torch.export-based ONNX Exporter#
Created On: Jun 10, 2025 | Last Updated On: Aug 22, 2025
Overview#
The torch.export engine is leveraged to produce a traced graph representing only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion. The resulting traced graph (1) produces normalized operators in the functional ATen operator set (as well as any user-specified custom operators), (2) has eliminated all Python control flow and data structures (with certain exceptions), and (3) records the set of shape constraints needed to show that this normalization and control-flow elimination is sound for future inputs. The traced graph is then translated into an ONNX graph.
In addition, during the export process, memory usage is significantly reduced.
Dependencies#
The ONNX exporter depends on two extra Python packages, onnx and onnxscript.
They can be installed through pip:
pip install --upgrade onnx onnxscript
onnxruntime can then be used to execute the model on a large variety of processors.
A simple example#
The following example demonstrates the exporter API with a simple Multilayer Perceptron (MLP):
import torch
import torch.nn as nn


class MLPModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(8, 8, bias=True)
        self.fc1 = nn.Linear(8, 4, bias=True)
        self.fc2 = nn.Linear(4, 2, bias=True)
        self.fc3 = nn.Linear(2, 2, bias=True)
        self.fc_combined = nn.Linear(8 + 8 + 8, 8, bias=True)  # Combine all inputs

    def forward(self, tensor_x: torch.Tensor, input_dict: dict, input_list: list):
        """
        Forward method that requires all inputs:
        - tensor_x: A direct tensor input.
        - input_dict: A dictionary containing the tensor under the key 'tensor_x'.
        - input_list: A list where the first element is the tensor.
        """
        # Extract tensors from inputs
        dict_tensor = input_dict['tensor_x']
        list_tensor = input_list[0]
        # Combine all inputs into a single tensor
        combined_tensor = torch.cat([tensor_x, dict_tensor, list_tensor], dim=1)
        # Process the combined tensor through the layers
        combined_tensor = self.fc_combined(combined_tensor)
        combined_tensor = torch.sigmoid(combined_tensor)
        combined_tensor = self.fc0(combined_tensor)
        combined_tensor = torch.sigmoid(combined_tensor)
        combined_tensor = self.fc1(combined_tensor)
        combined_tensor = torch.sigmoid(combined_tensor)
        combined_tensor = self.fc2(combined_tensor)
        combined_tensor = torch.sigmoid(combined_tensor)
        output = self.fc3(combined_tensor)
        return output


model = MLPModel()
# Example inputs
tensor_input = torch.rand((97, 8), dtype=torch.float32)
dict_input = {'tensor_x': torch.rand((97, 8), dtype=torch.float32)}
list_input = [torch.rand((97, 8), dtype=torch.float32)]
# The input_names and output_names are used to identify the inputs and outputs of the ONNX model
input_names = ['tensor_input', 'tensor_x', 'list_input_index_0']
output_names = ['output']
# Exporting the model with all required inputs
onnx_program = torch.onnx.export(
    model,
    (tensor_input, dict_input, list_input),
    dynamic_shapes=(
        {0: "batch_size"},
        {"tensor_x": {0: "batch_size"}},
        [{0: "batch_size"}],
    ),
    input_names=input_names,
    output_names=output_names,
    dynamo=True,
)
# Check the exported ONNX model is dynamic
assert onnx_program.model.graph.inputs[0].shape == ("batch_size", 8)
assert onnx_program.model.graph.inputs[1].shape == ("batch_size", 8)
assert onnx_program.model.graph.inputs[2].shape == ("batch_size", 8)
As the code above shows, all you need is to provide torch.onnx.export() with an instance of the model and its input. The exporter will then return an instance of torch.onnx.ONNXProgram that contains the exported ONNX graph along with extra information.
The in-memory model available through onnx_program.model_proto is an onnx.ModelProto object in compliance with the ONNX IR spec. The ONNX model may then be serialized into a Protobuf file using the torch.onnx.ONNXProgram.save() API.
onnx_program.save("mlp.onnx")
When the conversion fails#
If the conversion fails, call torch.onnx.export() a second time with the parameter report=True. A markdown report is then generated to help the user resolve the issue.
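The retry described here can be sketched as a small wrapper. This is only an illustration: the helper name export_with_report_on_failure is hypothetical, and the assumption that the report is written as a .md file into artifacts_dir is not confirmed by this page beyond the report and artifacts_dir parameter descriptions in the API reference. torch is imported lazily inside the function.

```python
from pathlib import Path


def export_with_report_on_failure(model, example_args, artifacts_dir="."):
    """Try a normal export; if it raises, repeat it with report=True.

    Hypothetical helper for illustration, not part of torch.onnx.
    """
    import torch  # lazy import: defining the helper does not require torch

    try:
        return torch.onnx.export(model, example_args, dynamo=True)
    except Exception as first_error:
        try:
            # Second call: same inputs, but ask the exporter to write a report.
            torch.onnx.export(
                model,
                example_args,
                dynamo=True,
                report=True,
                artifacts_dir=artifacts_dir,
            )
        except Exception:
            pass  # the export is expected to fail again; only the report matters
        # Assumption: the report is a Markdown file saved in artifacts_dir.
        reports = sorted(Path(artifacts_dir).glob("*.md"))
        raise RuntimeError(
            f"ONNX export failed; reports found: {[str(p) for p in reports]}"
        ) from first_error
```

On success the wrapper returns the ONNXProgram unchanged, so it can stand in for a direct torch.onnx.export() call during debugging.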
Metadata#
During ONNX export, each ONNX node is annotated with metadata that helps trace its origin and context from the original PyTorch model. This metadata is useful for debugging, model inspection, and understanding the mapping between PyTorch and ONNX graphs.
The following metadata fields are added to each ONNX node:
namespace
A string representing the hierarchical namespace of the node, consisting of a stack trace of modules/methods.
Example:
__main__.SimpleAddModel/add: aten.add.Tensor
pkg.torch.onnx.class_hierarchy
A list of class names representing the hierarchy of modules leading to this node.
Example:
['__main__.SimpleAddModel', 'aten.add.Tensor']
pkg.torch.onnx.fx_node
The string representation of the original FX node, including its name, number of consumers, the targeted torch op, arguments, and keyword arguments.
Example:
%cat : [num_users=1] = call_function[target=torch.ops.aten.cat.default](args = ([%tensor_x, %input_dict_tensor_x, %input_list_0], 1), kwargs = {})
pkg.torch.onnx.name_scopes
A list of name scopes (methods) representing the path to this node in the PyTorch model.
Example:
['', 'add']
pkg.torch.onnx.stack_trace
The stack trace from the original code where this node was created, if available.
Example:
File "simpleadd.py", line 7, in forward return torch.add(x, y)
These metadata fields are stored in the metadata_props attribute of each ONNX node and can be inspected using Netron or programmatically.
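As a minimal sketch of programmatic inspection, the function below gathers the metadata_props of every node into a dictionary. It assumes an object laid out like an onnx.ModelProto (for example onnx_program.model_proto), where each entry in metadata_props exposes key and value. The node name node_add and the stand-in objects are hypothetical and exist only so the sketch runs without onnx installed.

```python
from types import SimpleNamespace


def collect_node_metadata(model_proto):
    """Return {node name: {metadata key: value}} for every node in the graph."""
    return {
        node.name: {prop.key: prop.value for prop in node.metadata_props}
        for node in model_proto.graph.node
    }


# Stand-in with the same attribute layout as an onnx.ModelProto.
# The values mirror the examples listed above.
fake_node = SimpleNamespace(
    name="node_add",  # hypothetical node name
    metadata_props=[
        SimpleNamespace(
            key="namespace",
            value="__main__.SimpleAddModel/add: aten.add.Tensor",
        ),
        SimpleNamespace(key="pkg.torch.onnx.name_scopes", value="['', 'add']"),
    ],
)
fake_model = SimpleNamespace(graph=SimpleNamespace(node=[fake_node]))

metadata = collect_node_metadata(fake_model)
print(metadata["node_add"]["namespace"])
```

With a real export, passing onnx_program.model_proto instead of the stand-in should yield one entry per ONNX node.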
The overall ONNX graph has the following metadata_props:
pkg.torch.export.ExportedProgram.graph_signature
This property contains a string representation of the graph_signature from the original PyTorch ExportedProgram. The graph signature describes the structure of the model’s inputs and outputs and how they map to the ONNX graph. The inputs are defined as InputSpec objects, which include the kind of input (e.g., InputKind.PARAMETER for parameters, InputKind.USER_INPUT for user-defined inputs), the argument name, the target (which can be a specific node in the model), and whether the input is persistent. The outputs are defined as OutputSpec objects, which specify the kind of output (e.g., OutputKind.USER_OUTPUT) and the argument name.
To read more about the graph signature, see the torch.export documentation.
pkg.torch.export.ExportedProgram.range_constraints
This property contains a string representation of any range constraints that were present in the original PyTorch ExportedProgram. Range constraints specify valid ranges for symbolic shapes or values in the model, which can be important for models that use dynamic shapes or symbolic dimensions.
Example:
s0: VR[2, int_oo]
This indicates that the size of the input tensor must be at least 2.
To read more about range constraints, see the torch.export documentation.
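To make the range-constraint example concrete, here is a hedged sketch that parses one such entry into a symbol and numeric bounds. The pattern is inferred solely from the example s0: VR[2, int_oo]. The function name, the handling of -int_oo, and the assumption that both bounds are integers or infinities are illustrative and not taken from the torch.export specification.

```python
import math
import re

# Assumed layout "<symbol>: VR[<lower>, <upper>]", inferred from one example.
_CONSTRAINT = re.compile(
    r"^\s*(\w+)\s*:\s*VR\[\s*([^,\]]+?)\s*,\s*([^,\]]+?)\s*\]\s*$"
)


def _bound(text):
    """Convert a bound to a number; int_oo is treated as infinity."""
    if text == "int_oo":
        return math.inf
    if text == "-int_oo":  # assumption: negative infinity is spelled this way
        return -math.inf
    return int(text)


def parse_range_constraint(entry):
    """Parse a single entry into (symbol, lower, upper)."""
    match = _CONSTRAINT.match(entry)
    if match is None:
        raise ValueError(f"unrecognized range constraint: {entry!r}")
    symbol, lower, upper = match.groups()
    return symbol, _bound(lower), _bound(upper)


symbol, lower, upper = parse_range_constraint("s0: VR[2, int_oo]")
# A batch size of 97 satisfies the constraint; a batch size of 1 does not.
print(symbol, lower <= 97 <= upper, lower <= 1 <= upper)
```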
Each input value in the ONNX graph may have the following metadata property:
pkg.torch.export.graph_signature.InputSpec.kind
The kind of input, as defined by PyTorch’s InputKind enum.
Example values:
“USER_INPUT”: A user-provided input to the model.
“PARAMETER”: A model parameter (e.g., weight).
“BUFFER”: A model buffer (e.g., running mean in BatchNorm).
“CONSTANT_TENSOR”: A constant tensor argument.
“CUSTOM_OBJ”: A custom object input.
“TOKEN”: A token input.
pkg.torch.export.graph_signature.InputSpec.persistent
Indicates whether the input is persistent (i.e., should be saved as part of the model’s state).
Example values:
“True”
“False”
Each output value in the ONNX graph may have the following metadata property:
pkg.torch.export.graph_signature.OutputSpec.kind
The kind of output, as defined by PyTorch’s OutputKind enum.
Example values:
“USER_OUTPUT”: A user-visible output.
“LOSS_OUTPUT”: A loss value output.
“BUFFER_MUTATION”: Indicates a buffer was mutated.
“GRADIENT_TO_PARAMETER”: Gradient output for a parameter.
“GRADIENT_TO_USER_INPUT”: Gradient output for a user input.
“USER_INPUT_MUTATION”: Indicates a user input was mutated.
“TOKEN”: A token output.
Each initializer, input, and output value has the following metadata:
pkg.torch.onnx.original_node_name
The original name of the node in the PyTorch FX graph that produced this value in the case where the value was renamed. This helps trace initializers back to their source in the original model.
Example:
fc1.weight
API Reference#
- torch.onnx.export(model, args=(), f=None, *, kwargs=None, verbose=None, input_names=None, output_names=None, opset_version=None, dynamo=True, external_data=True, dynamic_shapes=None, custom_translation_table=None, report=False, optimize=True, verify=False, profile=False, dump_exported_program=False, artifacts_dir='.', fallback=False, export_params=True, keep_initializers_as_inputs=False, dynamic_axes=None, training=<TrainingMode.EVAL: 0>, operator_export_type=<OperatorExportTypes.ONNX: 0>, do_constant_folding=True, custom_opsets=None, export_modules_as_functions=False, autograd_inlining=True)[source]#
Exports a model into ONNX format.
Setting dynamo=True enables the new ONNX export logic, which is based on torch.export.ExportedProgram and a more modern set of translation logic. This is the recommended and default way to export models to ONNX.
When dynamo=True:
The exporter tries the following strategies to get an ExportedProgram for conversion to ONNX.
If the model is already an ExportedProgram, it will be used as-is.
Use torch.export.export() and set strict=False.
Use torch.export.export() and set strict=True.
- Parameters
model (torch.nn.Module | torch.export.ExportedProgram | torch.jit.ScriptModule | torch.jit.ScriptFunction) – The model to be exported.
args (tuple[Any, ...]) – Example positional inputs. Any non-Tensor arguments will be hard-coded into the exported model; any Tensor arguments will become inputs of the exported model, in the order they occur in the tuple.
f (str | os.PathLike | None) – Path to the output ONNX model file. E.g. “model.onnx”. This argument is kept for backward compatibility. It is recommended to leave unspecified (None) and use the returned torch.onnx.ONNXProgram to serialize the model to a file instead.
kwargs (dict[str, Any] | None) – Optional example keyword inputs.
verbose (bool | None) – Whether to enable verbose logging.
input_names (Sequence[str] | None) – Names to assign to the input nodes of the graph, in order.
output_names (Sequence[str] | None) – Names to assign to the output nodes of the graph, in order.
opset_version (int | None) – The version of the default (ai.onnx) opset to target. You should set opset_version according to the supported opset versions of the runtime backend or compiler you want to run the exported model with. Leave as default (None) to use the recommended version, or refer to the ONNX operators documentation for more information.
dynamo (bool) – Whether to export the model with torch.export ExportedProgram instead of TorchScript.
external_data (bool) – Whether to save the model weights as an external data file. This is required for models with large weights that exceed the ONNX file size limit (2GB). When False, the weights are saved in the ONNX file with the model architecture.
dynamic_shapes (dict[str, Any] | tuple[Any, ...] | list[Any] | None) – A dictionary or a tuple of dynamic shapes for the model inputs. Refer to torch.export.export() for more details. This is only used (and preferred) when dynamo is True. Note that dynamic_shapes is designed to be used when the model is exported with dynamo=True, while dynamic_axes is used when dynamo=False.
custom_translation_table (dict[Callable, Callable | Sequence[Callable]] | None) – A dictionary of custom decompositions for operators in the model. The dictionary should have the callable target in the fx Node as the key (e.g. torch.ops.aten.stft.default), and the value should be a function that builds that graph using ONNX Script. This option is only valid when dynamo is True.
report (bool) – Whether to generate a markdown report for the export process. This option is only valid when dynamo is True.
optimize (bool) – Whether to optimize the exported model. This option is only valid when dynamo is True. Default is True.
verify (bool) – Whether to verify the exported model using ONNX Runtime. This option is only valid when dynamo is True.
profile (bool) – Whether to profile the export process. This option is only valid when dynamo is True.
dump_exported_program (bool) – Whether to dump the torch.export.ExportedProgram to a file. This is useful for debugging the exporter. This option is only valid when dynamo is True.
artifacts_dir (str | os.PathLike) – The directory to save the debugging artifacts like the report and the serialized exported program. This option is only valid when dynamo is True.
fallback (bool) – Whether to fall back to the TorchScript exporter if the dynamo exporter fails. This option is only valid when dynamo is True. When fallback is enabled, it is recommended to set dynamic_axes even when dynamic_shapes is provided.
export_params (bool) –
When f is specified: If False, parameters (weights) will not be exported.
You can also leave it unspecified and use the returned torch.onnx.ONNXProgram to control how initializers are treated when serializing the model.
keep_initializers_as_inputs (bool) –
When f is specified: If True, all the initializers (typically corresponding to model weights) in the exported graph will also be added as inputs to the graph. If False, then initializers are not added as inputs to the graph, and only the user inputs are added as inputs.
Set this to True if you intend to supply model weights at runtime. Set it to False if the weights are static to allow for better optimizations (e.g. constant folding) by backends/runtimes.
You can also leave it unspecified and use the returned torch.onnx.ONNXProgram to control how initializers are treated when serializing the model.
dynamic_axes (Mapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | None) –
Prefer specifying dynamic_shapes when dynamo=True and when fallback is not enabled.
By default the exported model will have the shapes of all input and output tensors set to exactly match those given in args. To specify axes of tensors as dynamic (i.e. known only at run-time), set dynamic_axes to a dict with schema:
- KEY (str): an input or output name. Each name must also be provided in input_names or output_names.
- VALUE (dict or list): If a dict, keys are axis indices and values are axis names. If a list, each element is an axis index.
For example:
class SumModule(torch.nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1)


torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
)
Produces:
input {
  name: "x"
  ...
      shape {
        dim {
          dim_value: 2  # axis 0
        }
        dim {
          dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
      shape {
        dim {
          dim_value: 2  # axis 0
...
While:
torch.onnx.export(
    SumModule(),
    (torch.ones(2, 2),),
    "onnx.pb",
    input_names=["x"],
    output_names=["sum"],
    dynamic_axes={
        # dict value: manually named axes
        "x": {0: "my_custom_axis_name"},
        # list value: automatic names
        "sum": [0],
    },
)
Produces:
input {
  name: "x"
  ...
      shape {
        dim {
          dim_param: "my_custom_axis_name"  # axis 0
        }
        dim {
          dim_value: 2  # axis 1
...
output {
  name: "sum"
  ...
      shape {
        dim {
          dim_param: "sum_dynamic_axes_1"  # axis 0
...
training (_C_onnx.TrainingMode) – Deprecated option. Instead, set the training mode of the model before exporting.
operator_export_type (_C_onnx.OperatorExportTypes) – Deprecated option. Only ONNX is supported.
do_constant_folding (bool) – Deprecated option.
custom_opsets (Mapping[str, int] | None) – Deprecated option.
export_modules_as_functions (bool | Collection[type[torch.nn.Module]]) – Deprecated option.
autograd_inlining (bool) – Deprecated option.
- Returns
torch.onnx.ONNXProgram if dynamo is True, otherwise None.
- Return type
ONNXProgram | None
Changed in version 2.6: training is now deprecated. Instead, set the training mode of the model before exporting. operator_export_type is now deprecated. Only ONNX is supported. do_constant_folding is now deprecated. It is always enabled. export_modules_as_functions is now deprecated. autograd_inlining is now deprecated.
Changed in version 2.7: optimize is now True by default.
Changed in version 2.9: dynamo is now True by default.
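The two accepted forms of dynamic_axes described in the parameter list above (a dict of named axes, or a list of axis indices that receive automatic names) can be illustrated with a small normalization sketch. The helper name is hypothetical, and the naming rule is inferred only from the single generated name sum_dynamic_axes_1 in the example output. That the counter starts at 1 and follows the position in the list is an assumption.

```python
from collections.abc import Mapping


def normalize_dynamic_axes(dynamic_axes):
    """Rewrite list-valued entries as {axis index: generated axis name}.

    Hypothetical helper; the "<name>_dynamic_axes_<n>" scheme is inferred
    from the example output and may not match every exporter version.
    """
    normalized = {}
    for tensor_name, axes in dynamic_axes.items():
        if isinstance(axes, Mapping):
            # dict value: the user already named each axis
            normalized[tensor_name] = dict(axes)
        else:
            # list value: generate a name per listed axis, counting from 1
            normalized[tensor_name] = {
                axis: f"{tensor_name}_dynamic_axes_{position}"
                for position, axis in enumerate(axes, start=1)
            }
    return normalized


axes = normalize_dynamic_axes(
    {
        "x": {0: "my_custom_axis_name"},
        "sum": [0],
    }
)
print(axes)
```

Under these assumptions the result for "sum" matches the dim_param shown in the example output above.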
- class torch.onnx.ONNXProgram(model, exported_program)#
A class to represent an ONNX program that is callable with torch tensors.
- Variables
model – The ONNX model as an ONNX IR model object.
exported_program – The exported program that produced the ONNX model.
- apply_weights(state_dict)[source]#
Apply the weights from the specified state dict to the ONNX model.
Use this method to replace FakeTensors or other weights.
- Parameters
state_dict (dict[str, torch.Tensor]) – The state dict containing the weights to apply to the ONNX model.
- compute_values(value_names, args=(), kwargs=None)[source]#
Compute the values of the specified names in the ONNX model.
The values are returned as a dictionary mapping names to tensors.
- initialize_inference_session(initializer=<function _ort_session_initializer>)[source]#
Initialize the ONNX Runtime inference session.
- property model_proto: ModelProto#
Return the ONNX ModelProto object.
- optimize()[source]#
Optimize the ONNX model.
This method optimizes the ONNX model by performing constant folding and eliminating redundancies in the graph. The optimization is done in-place.
- release()[source]#
Release the inference session.
You may call this method to release the resources used by the inference session.
- save(destination, *, include_initializers=True, keep_initializers_as_inputs=False, external_data=None)[source]#
Save the ONNX model to the specified destination.
When external_data is True or the model is larger than 2GB, the weights are saved as external data in a separate file.
Initializer (model weights) serialization behaviors:
- include_initializers=True, keep_initializers_as_inputs=False (default): The initializers are included in the saved model.
- include_initializers=True, keep_initializers_as_inputs=True: The initializers are included in the saved model and kept as model inputs. Choose this option if you want the ability to override the model weights during inference.
- include_initializers=False, keep_initializers_as_inputs=False: The initializers are not included in the saved model and are not listed as model inputs. Choose this option if you want to attach the initializers to the ONNX model in a separate, post-processing, step.
- include_initializers=False, keep_initializers_as_inputs=True: The initializers are not included in the saved model but are listed as model inputs. Choose this option if you want to supply the initializers during inference and want to minimize the size of the saved model.
- Parameters
destination (str | os.PathLike) – The path to save the ONNX model to.
include_initializers (bool) – Whether to include the initializers in the saved model.
keep_initializers_as_inputs (bool) – Whether to keep the initializers as inputs in the saved model. If True, the initializers are added as inputs to the model, which means they can be overwritten by providing the initializers as model inputs.
external_data (Optional[bool]) – Whether to save the weights as external data in a separate file.
- Raises
TypeError – If external_data is True and destination is not a file path.
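As a rough illustration of the four initializer serialization behaviors of save(), the sketch below writes one file per combination. It only relies on the save() signature documented above. The helper name and the output file names are hypothetical, and onnx_program is expected to be a torch.onnx.ONNXProgram returned by torch.onnx.export().

```python
from pathlib import Path


def save_initializer_variants(onnx_program, out_dir):
    """Save the same program once per initializer setting; return the paths."""
    out_dir = Path(out_dir)
    # Hypothetical file names, one per documented combination.
    variants = {
        "weights_embedded.onnx": dict(
            include_initializers=True, keep_initializers_as_inputs=False
        ),
        "weights_embedded_overridable.onnx": dict(
            include_initializers=True, keep_initializers_as_inputs=True
        ),
        "weights_attached_later.onnx": dict(
            include_initializers=False, keep_initializers_as_inputs=False
        ),
        "weights_supplied_at_runtime.onnx": dict(
            include_initializers=False, keep_initializers_as_inputs=True
        ),
    }
    paths = []
    for filename, options in variants.items():
        destination = out_dir / filename
        onnx_program.save(destination, **options)
        paths.append(destination)
    return paths
```

In practice a single combination is usually enough. The first entry corresponds to the documented default.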
- torch.onnx.is_in_onnx_export()[source]#
Returns whether it is in the middle of ONNX export.
- Return type
bool
- class torch.onnx.OnnxExporterError#
Errors raised by the ONNX exporter. This is the base class for all exporter errors.