Autocast and Precision Management#
Note
This page documents the design for rule-based autocast in Torch-TensorRT. Original design discussion: RFC #3869.
Background#
TensorRT historically supported weak typing — the builder was allowed to select the lowest-precision kernel for each layer (e.g. downcast fp32 inputs to fp16 automatically). This behavior was deprecated in TensorRT 10.x. Torch-TensorRT replaces it with a PyTorch-native approach: a lowering pass that inserts explicit cast nodes before layers that should run in lower precision, following rules similar to NVIDIA ModelOpt Autocast.
Three Precision Modes#
| Mode | How precision is chosen | Status |
|---|---|---|
| User-defined per-layer precision | The user explicitly pins the dtype of individual layers | Supported |
| TRT weak typing | The TensorRT builder selects the lowest-precision kernel for each layer | Deprecated in TensorRT 10.x |
| Autocast (rule-based) | A pre-lowering pass inserts explicit cast nodes according to a predefined ruleset | Described on this page |
The autocast mode is the focus of this page.
User API#
import torch_tensorrt
trt_mod = torch_tensorrt.compile(
exported_program.module(),
arg_inputs=inputs,
min_block_size=1,
use_explicit_typing=True,
enable_autocast=True,
low_precision_type=torch.float16, # target low-precision dtype
nodes_to_exclude={"^conv2d$"}, # regex patterns → keep fp32
targets_to_exclude={},
data_max=512, # threshold for data-sensitive ops
max_depth_of_reduction=None,
)
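The nodes_to_exclude patterns are regular expressions matched against node names, so "^conv2d$" keeps conv2d in fp32 while still allowing conv2d_1 to be autocast (per the classifier output shown below). A minimal sketch of that matching, assuming anchored regex-search semantics (the helper name is illustrative, not the actual Torch-TensorRT API):

```python
import re

def is_excluded(node_name: str, patterns: set[str]) -> bool:
    # Assumed semantics: a node is excluded from autocast (kept fp32)
    # if any user-supplied regex matches its name.
    return any(re.search(p, node_name) for p in patterns)

patterns = {"^conv2d$"}
print(is_excluded("conv2d", patterns))    # anchored pattern matches
print(is_excluded("conv2d_1", patterns))  # suffixgit "_1" breaks the $ anchor
```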
Autocast-aware PyTorch code (torch.autocast context managers) is respected:
ops inside a torch.autocast(fp32) context remain fp32 even when the global
low-precision type is fp16.
Internal Implementation#
Stage 1 — Rule-Based Node Classifier#
Each node in the FX graph is assigned a precision (high or low) according to a predefined ruleset. A node stays in fp32 (high precision) if any of the following rules fire:
- It is inside a torch.autocast context (a wrap_with_autocast node).
- Its name matches a nodes_to_exclude regex.
- Its target matches targets_to_exclude.
- It performs a reduction over a large domain (the max_depth_of_reduction heuristic).
- The maximum data value flowing through it exceeds data_max (guards against overflow in fp16).
- It is an operator.getitem node (output unpacking).
All other nodes are assigned the low_precision_type.
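The ruleset can be sketched as a simple predicate chain. This is an illustrative stand-in, not the actual Torch-TensorRT classifier: Node here mimics torch.fx.Node, and the boolean fields stand in for graph analyses (autocast-context detection, value-range tracking, reduction-depth estimation) that the real pass performs.

```python
import operator
import re
from dataclasses import dataclass

@dataclass
class Node:
    # Stand-in for torch.fx.Node; flags summarize analyses the real pass runs.
    name: str
    target: object
    in_autocast_fp32: bool = False   # inside a torch.autocast(fp32) region?
    exceeds_data_max: bool = False   # observed data magnitude > data_max?
    deep_reduction: bool = False     # reduction depth > max_depth_of_reduction?

def classify(node, nodes_to_exclude=(), targets_to_exclude=()):
    """Return 'high' to keep the node in fp32, 'low' to autocast it."""
    if node.in_autocast_fp32:
        return "high"
    if any(re.search(p, node.name) for p in nodes_to_exclude):
        return "high"
    if node.target in targets_to_exclude:
        return "high"
    if node.deep_reduction or node.exceeds_data_max:
        return "high"
    if node.target is operator.getitem:  # output unpacking stays fp32
        return "high"
    return "low"
```

With nodes_to_exclude={"^conv2d$"}, classify keeps conv2d in fp32 but assigns conv2d_1 low precision, matching the example output below.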
For the example CNN in the RFC, the classifier output is:
Low Precision: relu, max_pool2d, conv2d_1, relu_1, max_pool2d_1, flatten
High Precision: conv2d (matches nodes_to_exclude="^conv2d$")
Stage 2 — Graph Modification (Pre-Lowering Pass)#
The pass is implemented in py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py. It inserts aten._to_copy (cast) nodes before each low-precision op to convert its inputs to the target dtype; high-precision nodes are left unmodified.
Example transformed graph fragment:
%convolution : fp32 (conv2d excluded from autocast)
%_to_copy : cast convolution → fp16
%relu : fp16
%max_pool2d : fp16
%_to_copy_1 : cast back → fp32 (before fc1 which is in autocast(fp32) context)
%linear : fp32
%add : fp32
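The insertion logic above can be sketched abstractly. The toy model below walks an op list in execution order and emits a cast whenever the required dtype changes; it is purely illustrative (the real pass operates on a torch.fx graph and inserts aten._to_copy nodes, not tuples):

```python
def insert_casts(ops, low_dtype="fp16", high_dtype="fp32"):
    """ops: list of (name, 'high'|'low') in execution order.
    Returns the op list with explicit cast nodes inserted."""
    out, current = [], high_dtype            # graph inputs assumed fp32
    for name, prec in ops:
        target = low_dtype if prec == "low" else high_dtype
        if target != current:                # dtype boundary -> explicit cast
            out.append((f"_to_copy->{target}", target))
            current = target
        out.append((name, current))
    return out

ops = [("conv2d", "high"), ("relu", "low"),
       ("max_pool2d", "low"), ("linear", "high")]
# Reproduces the fragment above: a cast to fp16 before relu,
# and a cast back to fp32 before linear.
```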
The resulting graph has an explicit dtype on every edge, satisfying TensorRT’s strong-typing requirement without relying on the deprecated weak-typing behavior.
Interaction with torch.autocast#
When a user wraps ops in torch.autocast:
with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
x = self.fc1(x)
x = torch.add(x, x)
the exported graph contains a torch.ops.higher_order.wrap_with_autocast node.
The classifier detects this and marks all ops inside the context as high-precision
(fp32), regardless of the global low_precision_type.