Runtime API#

torch_tensorrt.runtime exposes a set of context managers and utility functions for controlling TRT engine execution behavior after compilation. These APIs let you opt in to CUDA graph capture, pre-allocated output buffers, and weight streaming without recompiling the engine.


enable_cudagraphs#

Wraps a compiled torch.fx.GraphModule so that forward calls are captured and replayed as a CUDA graph. See CUDAGraphs and the Output Allocator for a full explanation of the two modes (whole-graph vs per-subgraph).

import torch_tensorrt

trt_model = torch_tensorrt.compile(model, ir="dynamo", arg_inputs=inputs)

with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cg_model:
    # First call records the graph; subsequent calls replay it
    output = cg_model(*inputs)
    output = cg_model(*inputs)  # fast replay
# CUDA graph recording is torn down on exit; trt_model is restored

Mode selection is automatic: if the model contains PyTorch fallback subgraphs, the entire module (TRT engines plus fallback subgraphs) is captured as a single CUDA graph using CudaGraphsTorchTensorRTModule (whole-graph mode); if the model is pure TRT, each TRT subgraph is captured and replayed individually (per-subgraph mode).

Whole-graph CUDA graph capture requires fixed input shapes. If your model uses ops with data-dependent output shapes, use enable_output_allocator instead; the two context managers are mutually exclusive and must not be nested.
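The record-then-replay behavior can be illustrated without TensorRT or a GPU. The sketch below is a hypothetical stand-in (ToyGraphRunner is invented for illustration): the first call "captures" the operation sequence, and later calls replay the recorded operations instead of re-tracing.

```python
class ToyGraphRunner:
    """Hypothetical stand-in for CUDA graph capture/replay semantics."""

    def __init__(self, fn):
        self.fn = fn
        self.recorded = None  # the "captured graph"
        self.replays = 0

    def __call__(self, x):
        if self.recorded is None:
            # First call: record the op sequence (capture phase)
            self.recorded = [("mul", 2), ("add", 1)]
            return self.fn(x)
        # Subsequent calls: replay the recorded ops without re-tracing
        self.replays += 1
        for op, arg in self.recorded:
            x = x * arg if op == "mul" else x + arg
        return x


runner = ToyGraphRunner(lambda x: x * 2 + 1)
first = runner(3)   # capture pass: 3 * 2 + 1 = 7
second = runner(3)  # replay pass: same result, no re-trace
```

The real runtime captures GPU kernel launches rather than Python ops, but the contract is the same: identical work is re-executed with far lower launch overhead.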


enable_pre_allocated_outputs#

Allocates output tensors once at the start of the context and reuses them for every subsequent forward call. This eliminates the overhead of output buffer allocation on the critical path and is useful for latency-sensitive inference loops.

with torch_tensorrt.runtime.enable_pre_allocated_outputs(trt_model) as pre_model:
    for batch in dataloader:
        output = pre_model(batch.cuda())
        # output is valid until the next call
# Output pre-allocation is released on exit

Warning

The output tensors are overwritten in-place on each call. Copy them before the next forward pass if you need to retain the values:

with torch_tensorrt.runtime.enable_pre_allocated_outputs(trt_model) as pre_model:
    result = pre_model(*inputs).clone()  # clone before next call
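The aliasing hazard the warning describes can be reproduced with a toy buffer-reusing model in plain Python (PreAllocatedModel is hypothetical, no TensorRT involved): every call writes into the same pre-allocated output, so a reference held from an earlier call silently changes.

```python
class PreAllocatedModel:
    """Hypothetical model that reuses one output buffer across calls."""

    def __init__(self):
        self.buffer = [0, 0, 0]  # pre-allocated output, reused every call

    def __call__(self, scale):
        for i in range(3):
            self.buffer[i] = scale * i  # overwrite in place
        return self.buffer


model = PreAllocatedModel()
first = model(1)      # [0, 1, 2] -- a view of the shared buffer
kept = list(first)    # explicit copy (analogous to .clone()) before next call
second = model(10)    # [0, 10, 20] -- `first` now shows these values too
```

Only `kept` preserves the first call's values; `first` and `second` are the same object.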

enable_output_allocator#

Activates TRT’s dynamic output allocator for models with data-dependent output shapes — ops like nonzero, unique, masked_select, or nms whose output size is not known at compile time.

with torch_tensorrt.runtime.enable_output_allocator(trt_model) as dds_model:
    # Works with variable-length outputs
    indices = dds_model(mask_tensor)

On entry, use_output_allocator is enabled on each TRT submodule; on exit it is disabled and the module reverts to standard static-allocation execution.
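Why these ops need a dynamic allocator can be seen with a pure-Python analogue of nonzero (toy_nonzero is invented for illustration): two inputs of identical shape produce outputs of different sizes, so no fixed output buffer can be reserved at compile time.

```python
def toy_nonzero(values):
    """Return indices of nonzero entries -- output length depends on data."""
    return [i for i, v in enumerate(values) if v != 0]


# Same input shape, different output sizes: a static allocator cannot
# reserve the right amount of output memory ahead of time.
a = toy_nonzero([0, 3, 0, 7])  # 2 indices
b = toy_nonzero([1, 1, 1, 1])  # 4 indices
```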

Note

enable_output_allocator is incompatible with CUDA graphs — do not nest it inside enable_cudagraphs.


weight_streaming#

See Resource Management for a complete guide. In brief:

with torch_tensorrt.runtime.weight_streaming(trt_model) as ctx:
    # Limit GPU memory used for weights to 1 GiB
    ctx.device_budget = 1 * 1024**3
    output = trt_model(*inputs)
# Budget is restored to the original value on exit

weight_streaming requires the model to be compiled with enable_weight_streaming=True and use_explicit_typing=True.
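The budget save-and-restore behavior can be sketched with a toy context manager (ToyWeightStreamingCtx and ToyModel are hypothetical; the real context additionally repartitions which weights stay resident on the GPU):

```python
class ToyModel:
    device_budget = 4 * 1024**3  # 4 GiB default weight budget


class ToyWeightStreamingCtx:
    """Hypothetical sketch of weight_streaming's save/restore behavior."""

    def __init__(self, model):
        self.model = model

    def __enter__(self):
        self._saved = self.model.device_budget  # remember original budget
        return self

    @property
    def device_budget(self):
        return self.model.device_budget

    @device_budget.setter
    def device_budget(self, nbytes):
        # The real runtime would re-stream weights to fit the new cap
        self.model.device_budget = nbytes

    def __exit__(self, *exc):
        self.model.device_budget = self._saved  # restore on exit


m = ToyModel()
with ToyWeightStreamingCtx(m) as ctx:
    ctx.device_budget = 1 * 1024**3  # cap streamed weights at 1 GiB
    inside = m.device_budget
after = m.device_budget  # original budget restored
```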


set_cudagraphs_mode / get_cudagraphs_mode#

Low-level global CUDA graph mode control (useful for testing or profiling):

from torch_tensorrt.runtime import (
    set_cudagraphs_mode,
    get_cudagraphs_mode,
    get_whole_cudagraphs_mode,
)

# Enable per-subgraph CUDA graphs globally
set_cudagraphs_mode(True)
print(get_cudagraphs_mode())       # True
print(get_whole_cudagraphs_mode()) # False

# Enable whole-graph CUDA graphs via the context manager
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cg_model:
    print(get_whole_cudagraphs_mode())  # True (if model has fallback subgraphs)

Prefer enable_cudagraphs over manual set_cudagraphs_mode calls — the context manager handles mode restoration on exit automatically.
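The reason the context manager is safer can be shown with a minimal sketch of the restore-on-exit pattern (the globals and function names below are invented stand-ins for the runtime's flag, not the actual implementation):

```python
import contextlib

_cudagraphs_mode = False  # stand-in for the runtime's global flag


def set_mode(enabled):
    global _cudagraphs_mode
    _cudagraphs_mode = enabled


def get_mode():
    return _cudagraphs_mode


@contextlib.contextmanager
def enable_mode():
    """Sketch of the restore-on-exit pattern enable_cudagraphs follows."""
    previous = get_mode()
    set_mode(True)
    try:
        yield
    finally:
        set_mode(previous)  # restored even if the body raises


with enable_mode():
    inside = get_mode()   # True
outside = get_mode()      # False -- restored automatically
```

A bare set_cudagraphs_mode(True) with no matching restore would leave the global flag set for every later caller in the process.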


set_multi_device_safe_mode#

Enables a thread-safe mode for running the same compiled TRT module on multiple CUDA devices concurrently. When enabled, the runtime serializes device context switches before each engine execution:

torch_tensorrt.runtime.set_multi_device_safe_mode(True)

# Now safe to call trt_model from multiple threads on different devices
output = trt_model(*inputs)

torch_tensorrt.runtime.set_multi_device_safe_mode(False)

This incurs a small overhead per forward call. Only enable it when genuinely running across multiple devices from the same Python process.
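The serialization idea can be sketched with a lock guarding the device switch plus launch as one critical section (a minimal illustration, not the runtime's actual implementation; all names below are invented):

```python
import threading

_device_lock = threading.Lock()  # stand-in for the safe-mode lock
_active_device = {"id": 0}
execution_log = []


def run_engine(device_id, engine_name):
    """Serialize the device switch and launch, as safe mode does."""
    with _device_lock:
        if _active_device["id"] != device_id:
            _active_device["id"] = device_id  # cudaSetDevice(...) stand-in
        execution_log.append((device_id, engine_name))


threads = [
    threading.Thread(target=run_engine, args=(i % 2, f"engine{i}"))
    for i in range(4)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Without the lock, one thread could switch the active device between another thread's device check and its kernel launch, sending work to the wrong GPU.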