CUDAGraphs#
Note
This page documents the design for CUDAGraphs support in Torch-TensorRT. Original design discussion: RFC #2736.
Goal#
Reduce inference latency by hiding CUDA kernel launch overhead using CUDA Graphs. A CUDA Graph records a sequence of GPU operations and replays them as a single batched launch, eliminating the per-kernel CPU launch overhead that can dominate for small models or small-batch workloads.
User API#
import torch_tensorrt
# torch.export + dynamo.compile path
trt_gm = torch_tensorrt.compile(
    model,
    arg_inputs=inputs,
    use_cuda_graph=True,
)
# torch.export path — explicit CUDAGraphs module
import torch_tensorrt.runtime as runtime
with runtime.enable_cudagraphs(trt_gm) as cudagraphs_module:
    output = cudagraphs_module(*inputs)
Limitations#
CUDAGraphs require static input shapes. A graph recorded for one input shape is invalid for a different shape and must be re-recorded.
Memory addresses must be stable across replay. Input and output buffers must be allocated before recording and reused on every replay (no allocation inside the graph).
Not compatible with models that have data-dependent shapes (DDS ops such as nonzero).
The torch.compile (JIT) path supports recompilation on shape changes via PyTorch guards; the torch.export (AOT) path requires static shapes or explicit per-shape re-recording.
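The static-shape and stable-address constraints above follow the standard PyTorch capture pattern. A minimal sketch (a hypothetical helper, not Torch-TensorRT's internal code) using torch.cuda.CUDAGraph shows why: capture binds the graph to specific buffer addresses and a specific shape, so new inputs must be copied into the captured buffer, and a shape change forces re-recording.

```python
import torch

def record_graph(fn, static_input):
    """Capture fn(static_input) into a CUDA graph; return a replay callable.

    Hypothetical helper illustrating the constraints; requires a CUDA device.
    """
    # Warmup on a side stream, outside the graph, so lazy init is not captured
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(3):
            fn(static_input)
    torch.cuda.current_stream().wait_stream(s)

    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        static_output = fn(static_input)  # runs on the captured addresses

    def replay(x):
        # The graph is only valid for the recorded shape
        assert x.shape == static_input.shape, "shape change requires re-recording"
        static_input.copy_(x)   # copy into the captured input buffer (no new alloc)
        g.replay()
        return static_output    # same output buffer every call; clone() to keep it
    return replay
```

Passing a tensor of a different shape trips the assertion; the caller would have to call record_graph again for the new shape.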
Internal Implementation#
Recording Strategy#
Recording happens after the engine is built, as a wrapper over the execute-engine call. For each TRT subgraph, the runtime:
Performs a warmup run outside the graph to initialize internal TRT state.
Begins CUDA stream capture with cudaStreamBeginCapture.
Calls execute_async_v3 (the TensorRT asynchronous execution call).
Ends capture with cudaStreamEndCapture to obtain a cudaGraph_t.
Instantiates the graph with cudaGraphInstantiate.
Subsequent inference launches the instantiated graph instead of calling
execute_async_v3 directly:
# Pseudocode reflecting the recording pattern from TRT docs
from cuda import cudart
err, stream = cudart.cudaStreamCreate()
# Warmup: required to update TRT internal state for the input shape
context.execute_async_v3(stream)
# Capture
err, = cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
context.execute_async_v3(stream)
err, graph = cudart.cudaStreamEndCapture(stream)
err, instance = cudart.cudaGraphInstantiate(graph, 0)
# Replay
for _ in range(iterations):
    cudart.cudaGraphLaunch(instance, stream)
    cudart.cudaStreamSynchronize(stream)
Graph Storage#
Each runtime module (both C++ TorchTensorRTModule and Python
PythonTorchTensorRTModule) stores a cudaGraphExec_t instance. When
use_cuda_graph=True is set at compile time the runtime records one graph
per engine for the first input shape encountered.
For ir="torch_compile" (JIT path), guard-triggered recompilation provides
dynamic shape support: a new graph is recorded whenever a guard fails and the
model is retraced. This mirrors PyTorch Inductor’s mode="reduce-overhead".
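The per-shape recording policy can be pictured as a cache keyed by input shapes: replay on a hit, re-record on a miss (the JIT path's guard failure). A minimal pure-Python sketch, with a hypothetical GraphCache class standing in for the runtime module's storage:

```python
from typing import Callable, Dict, Tuple

# Key: one shape tuple per input tensor
ShapeKey = Tuple[Tuple[int, ...], ...]

class GraphCache:
    """Hypothetical per-shape graph cache: record once per shape key, then replay."""

    def __init__(self, record: Callable[[ShapeKey], Callable[[], object]]):
        self._record = record                              # records a graph for a shape key
        self._graphs: Dict[ShapeKey, Callable[[], object]] = {}
        self.recordings = 0                                # how many times we re-recorded

    def launch(self, shapes: ShapeKey):
        if shapes not in self._graphs:                     # miss == guard failure: re-record
            self._graphs[shapes] = self._record(shapes)
            self.recordings += 1
        return self._graphs[shapes]()                      # hit: replay the stored graph
```

With use_cuda_graph=True on the AOT path, only the first shape's entry is ever created; the JIT path can populate new entries as guards fail.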
_CudaGraphsTorchTensorRTModule#
_CudaGraphsTorchTensorRTModule is a dedicated Python runtime wrapper that
combines TRT execution with CUDAGraph management. It:
Pre-allocates a static output buffer sized for the compiled output shapes.
Records the full forward pass (across all TRT subgraphs and interleaved PyTorch ops) into a single torch.cuda.CUDAGraph.
On each forward call, copies inputs into the pre-allocated input buffers and replays the graph.
This gives the lowest-latency execution path for static-shape models.
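The whole-forward capture described above can be sketched as a module wrapper (a hypothetical CudaGraphWrapper, not the actual _CudaGraphsTorchTensorRTModule) that pre-allocates static input and output buffers, captures the inner module once, and thereafter only copies in and replays:

```python
import torch

class CudaGraphWrapper(torch.nn.Module):
    """Sketch of the copy-in / replay / read-static-output pattern; needs CUDA."""

    def __init__(self, inner: torch.nn.Module, sample: torch.Tensor):
        super().__init__()
        self.inner = inner
        self.static_in = sample.clone()          # pre-allocated input buffer
        # Warmup outside the graph so lazy initialization is not captured
        s = torch.cuda.Stream()
        s.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(s):
            for _ in range(3):
                inner(self.static_in)
        torch.cuda.current_stream().wait_stream(s)
        # Capture the full forward pass into a single graph
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            self.static_out = inner(self.static_in)   # pre-allocated output buffer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.static_in.copy_(x)     # inputs must land at the captured addresses
        self.graph.replay()         # one launch for the whole forward pass
        return self.static_out      # same buffer every call; clone() to keep results
```

Because replay reuses the captured addresses, callers that need to retain an output across calls must clone it before the next forward.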