Autocast and Precision Management#

Note

This page documents the design for rule-based autocast in Torch-TensorRT. Original design discussion: RFC #3869.

Background#

TensorRT historically supported weak typing — the builder was allowed to select the lowest-precision kernel for each layer (e.g. downcast fp32 inputs to fp16 automatically). This behavior was deprecated in TensorRT 10.x. Torch-TensorRT replaces it with a PyTorch-native approach: a lowering pass that inserts explicit cast nodes before layers that should run in lower precision, following rules similar to NVIDIA ModelOpt Autocast.

Three Precision Modes#

Mode                               use_explicit_typing   enable_autocast
---------------------------------  --------------------  ---------------
User-defined per-layer precision   True                  False
TRT weak typing (deprecated)       False                 False
Autocast (rule-based)              True                  True

The autocast mode is the focus of this page.

User API#

import torch
import torch_tensorrt

trt_mod = torch_tensorrt.compile(
    exported_program.module(),
    arg_inputs=inputs,
    min_block_size=1,
    use_explicit_typing=True,
    enable_autocast=True,
    low_precision_type=torch.float16,   # target low-precision dtype
    nodes_to_exclude={"^conv2d$"},       # regex patterns → keep fp32
    targets_to_exclude={},
    data_max=512,                        # threshold for data-sensitive ops
    max_depth_of_reduction=None,
)

Autocast-aware PyTorch code (torch.autocast context managers) is respected: ops inside a torch.autocast(fp32) context remain fp32 even when the global low-precision type is fp16.

Internal Implementation#

Stage 1 — Rule-Based Node Classifier#

Each node in the FX graph is assigned a precision (high or low) according to a predefined ruleset. A node stays in fp32 (high precision) if any of the following rules fire:

  • It is inside a torch.autocast context (wrap_with_autocast node).

  • Its name matches a nodes_to_exclude regex.

  • Its target matches targets_to_exclude.

  • It performs a reduction over a large domain (max_depth_of_reduction heuristic).

  • The maximum data value flowing through it exceeds data_max (guards against overflow in fp16).

  • It is an operator.getitem node (output unpacking).

All other nodes are assigned the low_precision_type.

For the example CNN in the RFC, the classifier output is:

Low Precision:  relu, max_pool2d, conv2d_1, relu_1, max_pool2d_1, flatten
High Precision: conv2d  (matches nodes_to_exclude="^conv2d$")
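The exclusion rules above can be sketched as a small, self-contained classifier. This is an illustrative sketch, not the Torch-TensorRT implementation: node objects here only need `name` and `target` attributes (mirroring `torch.fx.Node`), and the autocast-context, `data_max`, and `max_depth_of_reduction` rules are omitted for brevity.

```python
import operator
import re
from types import SimpleNamespace

def classify_nodes(nodes, nodes_to_exclude=(), targets_to_exclude=()):
    """Map each node name to "high" (stays fp32) or "low" precision.

    Illustrative subset of the Stage 1 ruleset: only the regex name
    exclusion, target exclusion, and getitem rules are modeled.
    """
    precision = {}
    for node in nodes:
        keep_fp32 = (
            node.target is operator.getitem                      # output unpacking
            or any(re.search(p, node.name) for p in nodes_to_exclude)
            or node.target in targets_to_exclude
        )
        precision[node.name] = "high" if keep_fp32 else "low"
    return precision

# Reproduces the RFC example above: "^conv2d$" matches conv2d but not conv2d_1.
nodes = [SimpleNamespace(name=n, target=n)
         for n in ["conv2d", "relu", "max_pool2d", "conv2d_1", "relu_1"]]
print(classify_nodes(nodes, nodes_to_exclude={"^conv2d$"}))
```

Note that the anchored regex keeps only `conv2d` in fp32; `conv2d_1` still receives the low-precision type, matching the classifier output shown above.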

Stage 2 — Graph Modification (Pre-Lowering Pass)#

The pass is implemented in py/torch_tensorrt/dynamo/lowering/passes/rule_based_autocast.py. It inserts aten._to_copy (cast) nodes before each low-precision op to convert its inputs to the target dtype. High-precision nodes are left unmodified.

Example transformed graph fragment:

%convolution   : fp32  (conv2d excluded from autocast)
%_to_copy      : cast convolution → fp16
%relu          : fp16
%max_pool2d    : fp16
%_to_copy_1    : cast back → fp32  (before fc1 which is in autocast(fp32) context)
%linear        : fp32
%add           : fp32

The resulting graph has an explicit dtype on every edge, satisfying TensorRT’s strong-typing requirement without relying on the deprecated weak-typing behavior.

Interaction with torch.autocast#

When a user wraps ops in torch.autocast:

with torch.autocast(x.device.type, enabled=True, dtype=torch.float32):
    x = self.fc1(x)
    x = torch.add(x, x)

the exported graph contains a torch.ops.higher_order.wrap_with_autocast node. The classifier detects this and marks all ops inside the context as high-precision (fp32), regardless of the global low_precision_type.
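How the context rule overrides the global low-precision choice can be sketched as a post-processing step over the Stage 1 decisions. This is illustrative only: the real classifier inspects the `wrap_with_autocast` node's subgraph directly, and the hand-written region list below stands in for what would be extracted from that higher-order op.

```python
def apply_autocast_contexts(precision, autocast_regions):
    """Pin ops traced inside an fp32 torch.autocast context to fp32.

    precision:        node name -> "high" or "low" (output of Stage 1)
    autocast_regions: list of (dtype, node_names) pairs, one per
                      wrap_with_autocast body in the exported graph
    """
    for dtype, names in autocast_regions:
        if dtype == "fp32":
            for name in names:
                # The user's explicit autocast(fp32) wins over the global
                # low_precision_type.
                precision[name] = "high"
    return precision

# fc1 (linear) and add were wrapped in torch.autocast(dtype=torch.float32),
# so they stay fp32 while relu still runs in fp16.
result = apply_autocast_contexts(
    {"linear": "low", "add": "low", "relu": "low"},
    [("fp32", ["linear", "add"])],
)
print(result)
# {'linear': 'high', 'add': 'high', 'relu': 'low'}
```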