Distributed Inference#
Torch-TensorRT supports two distributed inference patterns:
Data parallel — the same TRT-compiled model runs independently on multiple GPUs, each processing a different shard of the batch.
Tensor parallel — the model is sharded across GPUs using torch.distributed and DTensors; Torch-TensorRT compiles each per-GPU shard.
Both patterns produce standard TRT engines; Torch-TensorRT handles only the compilation step — data movement and process coordination remain the responsibility of the distributed framework you choose.
Data Parallel Inference#
The simplest path is to replicate the compiled model on each GPU process using Accelerate or PyTorch DDP. Each process compiles its own TRT engine independently.
# Run with: torchrun --nproc_per_node=<N> script.py
import torch
import torch_tensorrt
from accelerate import PartialState

distributed_state = PartialState()
device = distributed_state.device  # GPU assigned to this rank

model = MyModel().eval().to(device)
inputs = [torch.randn(1, 3, 224, 224).to(device)]

with distributed_state.split_between_processes(inputs) as local_inputs:
    trt_model = torch_tensorrt.compile(
        model,
        ir="dynamo",
        arg_inputs=local_inputs,
        use_explicit_typing=True,  # enabled_precisions deprecated; cast model/inputs to target dtype
        min_block_size=1,
    )
    output = trt_model(*local_inputs)
See the data_parallel_gpt2.py and data_parallel_stable_diffusion.py examples for complete Accelerate-based workflows with GPT-2 and Stable Diffusion.
Tensor Parallel Inference#
Tensor parallelism shards model weights across GPUs. Each GPU holds a slice of every weight tensor and participates in collective operations (all-gather, reduce-scatter) to execute the full model forward pass.
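The arithmetic behind this sharding can be checked on a single device. The sketch below (plain PyTorch, no distributed runtime) simulates a two-GPU column-parallel / row-parallel split of a two-layer MLP; the final summation stands in for the all-reduce that would run across GPUs in a real job:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, 16)
w1 = torch.randn(16, 32)  # fc1 weight (column-parallel: split the output dim)
w2 = torch.randn(32, 16)  # fc2 weight (row-parallel: split the input dim)

# Reference: unsharded two-layer forward
ref = (x @ w1) @ w2

# "Two GPUs" simulated locally: each rank holds half the columns of w1
# and the matching half of the rows of w2.
w1_shards = w1.chunk(2, dim=1)
w2_shards = w2.chunk(2, dim=0)

# Each rank computes a partial result; summing the partials is exactly
# what the all-reduce collective does in a real tensor-parallel forward.
partials = [(x @ a) @ b for a, b in zip(w1_shards, w2_shards)]
out = sum(partials)

assert torch.allclose(ref, out, atol=1e-4)
```

This is why collective ops are unavoidable in tensor parallelism: each GPU's partial product is incomplete until the cross-GPU reduction combines them.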
Torch-TensorRT compiles the per-GPU shard of the model. Use
use_distributed_mode_trace=True to switch the export path to aot_autograd, which
handles DTensor inputs correctly:
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)
import torch_tensorrt

# --- Initialize process group ---
dist.init_process_group(backend="nccl")
rank = dist.get_rank()
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),))

# --- Shard the model across GPUs ---
model = MyModel().eval().to(device)
parallelize_module(
    model,
    tp_mesh,
    {
        "fc1": ColwiseParallel(),
        "fc2": RowwiseParallel(),
    },
)

# --- Compile each GPU's shard with Torch-TensorRT ---
inputs = [torch.randn(8, 512).to(device)]
trt_model = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "use_distributed_mode_trace": True,
        "use_explicit_typing": True,  # enabled_precisions deprecated
        "use_python_runtime": True,
        "min_block_size": 1,
    },
)
output = trt_model(*inputs)
use_distributed_mode_trace=True is required whenever:
The model contains DTensor parameters (from parallelize_module).
The model uses torch.distributed collective ops that appear as graph nodes.
Without it, the default export path (aot_export_joint_simple) will fail on DTensor inputs and Torch-TensorRT will emit a warning.
NCCL Collective Ops in TRT Graphs#
For models where collective ops (all-gather, reduce-scatter) appear inside the TRT-compiled subgraph, Torch-TensorRT can fuse them using TensorRT-LLM plugins.
The fuse_distributed_ops lowering pass automatically rewrites consecutive
_c10d_functional.all_gather_into_tensor / reduce_scatter_tensor +
wait_tensor pairs into fused custom ops:
tensorrt_fused_nccl_all_gather_op
tensorrt_fused_nccl_reduce_scatter_op
These are then converted by custom converters into TensorRT-LLM AllGather / ReduceScatter
plugin layers (requires ENABLED_FEATURES.trtllm_for_nccl). When the TRT-LLM plugin is
unavailable, the ops fall back to PyTorch execution transparently.
See the tensor_parallel_rotary_embedding.py example for a Llama-style model with NCCL collective ops compiled end-to-end.
Compilation Settings for Distributed Workloads#
| Setting | Default | Description |
|---|---|---|
| use_distributed_mode_trace | False | Use the aot_autograd export path; required for DTensor inputs and collective ops in the graph |
| use_python_runtime | False | Use the Python runtime. Often set to True for distributed workloads |
| use_explicit_typing | False | Respect dtypes set in model/inputs (recommended). Use in place of the deprecated enabled_precisions |
Launching Distributed Scripts#
Use torchrun to launch multi-GPU scripts:
# 4-GPU tensor-parallel job
torchrun --nproc_per_node=4 tensor_parallel_example.py
# 2-node, 8-GPU data-parallel job
torchrun --nnodes=2 --nproc_per_node=4 --rdzv_backend=c10d \
--rdzv_endpoint=$MASTER_ADDR:$MASTER_PORT data_parallel_example.py
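torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK environment variables into every process it launches. A minimal sketch for reading them (the fallback defaults let the same script also run standalone as a single process):

```python
import os

# Set by torchrun for each launched process; defaults cover standalone runs.
rank = int(os.environ.get("RANK", 0))              # global rank across all nodes
world_size = int(os.environ.get("WORLD_SIZE", 1))  # total number of processes
local_rank = int(os.environ.get("LOCAL_RANK", 0))  # GPU index on this node

print(f"rank {rank}/{world_size}, local GPU {local_rank}")
```

In a multi-node job, local_rank (not the global rank) is what should select the CUDA device on each node.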
Examples#
The following complete examples are available in the Torch-TensorRT repository under
examples/distributed_inference/:
| Script | Description |
|---|---|
| data_parallel_gpt2.py | Data-parallel GPT-2 inference with Accelerate |
| data_parallel_stable_diffusion.py | Data-parallel Stable Diffusion with Accelerate |
| tensor_parallel_simple_example.py | Two-layer MLP with column / row parallel sharding |
| tensor_parallel_rotary_embedding.py | Llama-3 RoPE module with NCCL collective ops |