# Tracing Models with torch_tensorrt.dynamo.trace
torch_tensorrt.dynamo.trace() is a thin wrapper around
torch.export.export that automatically injects Torch-TensorRT’s operator
decompositions at export time, producing an ExportedProgram that is better
suited for TRT compilation than a vanilla export.
## Why Use trace Instead of torch.export.export Directly?
Both paths produce a `torch.export.ExportedProgram`. The difference is in the
decompositions applied:

- ``torch.export.export`` — applies a default set of decompositions, some of which produce composite ops (e.g. `aten.linear`, `aten.embedding`) that Torch-TensorRT must decompose again in its own lowering pass.
- ``torch_tensorrt.dynamo.trace`` — applies Torch-TensorRT's curated ATen decomposition set upfront, producing a graph closer to the Core ATen opset that the converter library expects. This can reduce lowering time and increase TRT coverage for certain model architectures.
For most models the difference is minor. Use trace when:

- You are already using `torch_tensorrt` at the export step and want a single API.
- Your model has ops that Torch-TensorRT's lowering pass handles better when decomposed at export time.
- You want the tracing step to validate Torch-TensorRT compatibility before compilation.
## Basic Usage
```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

exp_program = torch_tensorrt.dynamo.trace(model, arg_inputs=inputs)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, arg_inputs=inputs)
```
## Dynamic Shapes
Pass torch_tensorrt.Input objects to specify dynamic dimensions. trace
extracts torch.export.Dim constraints from the min/opt/max shapes and
passes them to torch.export.export automatically:
```python
from torch_tensorrt import Input

dyn_inputs = [
    Input(
        min_shape=(1, 3, 224, 224),
        opt_shape=(8, 3, 224, 224),
        max_shape=(16, 3, 224, 224),
        dtype=torch.float32,
        name="x",
    )
]

exp_program = torch_tensorrt.dynamo.trace(model, arg_inputs=dyn_inputs)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, arg_inputs=dyn_inputs)
```
Under the hood, trace calls get_dynamic_shapes_args / get_dynamic_shapes_kwargs
to build the dynamic_shapes dict required by torch.export.export, using
the Input’s min/max range to construct a torch.export.Dim for each dynamic
axis. Axes where min == max are treated as static.
## Keyword Inputs
```python
# Model with a keyword-only argument: def forward(self, x, *, mask=None)
exp_program = torch_tensorrt.dynamo.trace(
    model,
    arg_inputs=(torch.randn(1, 512).cuda(),),
    kwarg_inputs={"mask": torch.ones(1, 512, dtype=torch.bool).cuda()},
)
```
Both arg_inputs and kwarg_inputs accept torch.Tensor (static) or
torch_tensorrt.Input (dynamic) values.
## Non-Strict Tracing
trace uses strict=False by default when calling torch.export.export.
Non-strict mode allows data-dependent control flow and Python-side tensor operations
that strict export would reject. If you need strict export semantics, pass
strict=True:
```python
exp_program = torch_tensorrt.dynamo.trace(model, arg_inputs=inputs, strict=True)
```
## trace vs compile Entry Points
| | trace + compile | compile directly |
|---|---|---|
| ExportedProgram reuse | Yes — export once, compile multiple times with different settings | No — re-exports on every compile call |
| Decompositions at export | Torch-TRT curated set | Applied during compile's lowering pass |
| Inspect pre-compilation graph | Yes — inspect | No — graph only visible inside compile |
| | N/A | Supported via |