Distributed Inference#
Torch-TensorRT supports two distributed inference patterns:
Data parallel — the same TRT-compiled model runs independently on multiple GPUs, each processing a different shard of the batch.
Tensor parallel — the model is sharded across GPUs using
torch.distributedand DTensors; Torch-TensorRT compiles each per-GPU shard.
Both patterns produce standard TRT engines; Torch-TensorRT handles only the compilation step — data movement and process coordination remain the responsibility of the distributed framework you choose.
Data Parallel Inference#
The simplest path is to replicate the compiled model on each GPU process using Accelerate or PyTorch DDP. Each process compiles its own TRT engine independently.
# Run with: torchrun --nproc_per_node=<N> script.py
import torch
import torch_tensorrt
from accelerate import PartialState
distributed_state = PartialState()
device = distributed_state.device # GPU assigned to this rank
model = MyModel().eval().to(device)
inputs = [torch.randn(1, 3, 224, 224).to(device)]
with distributed_state.split_between_processes(inputs) as local_inputs:
trt_model = torch_tensorrt.compile(
model,
ir="dynamo",
arg_inputs=local_inputs,
min_block_size=1,
)
output = trt_model(*local_inputs)
See the data_parallel_gpt2.py and data_parallel_stable_diffusion.py examples for complete Accelerate-based workflows with GPT-2 and Stable Diffusion.
Tensor Parallel Inference#
Tensor parallelism shards model weights across GPUs. Each GPU holds a slice of every weight tensor and participates in collective operations (all-reduce, all-gather, reduce-scatter) to execute the full model forward pass.
Torch-TensorRT supports two compilation workflows for tensor-parallel models:
torch.compile (JIT) — uses
torch._dynamoto trace the model at runtime. Works with DTensor-parallelized models directly.torch.export (AOT) — ahead-of-time export, TRT compilation, and save/load. Requires manual weight slicing (DTensor is not yet supported by
torch.export).
Workflow 1: torch.compile (JIT)#
The simplest path for tensor-parallel inference. Shard the model with
parallelize_module (DTensor), compile with torch.compile, and wrap
inference in distributed_context for safe NCCL lifecycle management:
import os
import torch
import torch.distributed as dist
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
import torch_tensorrt
dist.init_process_group(backend="nccl")
device = torch.device(f"cuda:{dist.get_rank()}")
tp_mesh = dist.device_mesh.init_device_mesh("cuda", (dist.get_world_size(),))
model = MyModel().eval().half().to(device)
parallelize_module(
model,
tp_mesh,
{
"attn.q_proj": ColwiseParallel(),
"attn.o_proj": RowwiseParallel(),
"mlp.gate_proj": ColwiseParallel(),
"mlp.down_proj": RowwiseParallel(),
},
)
trt_model = torch.compile(
model,
backend="torch_tensorrt",
dynamic=True,
options={
"use_distributed_mode_trace": True,
"use_python_runtime": False,
"min_block_size": 1,
},
)
# Warmup — triggers engine build
_ = trt_model(input_ids, position_ids=position_ids)
dist.barrier()
# Use distributed_context to manage the NCCL lifecycle.
# On __exit__ it releases NCCL communicators from TRT engines,
# making dist.destroy_process_group() safe to call afterward.
with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as model:
output = model(input_ids, position_ids=position_ids)
dist.destroy_process_group()
os._exit(0) # bypass Python GC destructor races
Key points:
Use
distributed_contextfor safe teardown — TRT engines hold a raw pointer to the NCCL communicator.distributed_context(group, module)tracks all multi-device engines and callsrelease_nccl_comm()on__exit__, detaching the communicator sodist.destroy_process_group()doesn’t cause a use-after-free. Always follow the block withdist.destroy_process_group()andos._exit(0).Automatic distributed tracing — when
dist.init_process_group()has been called andworld_size > 1, Torch-TensorRT detects the active distributed context and automatically switches thetorch.compilebackend tracer from the defaulttorch._dynamopath toaot_autograd. You do not need to setuse_distributed_mode_trace=Trueexplicitly.dynamic=Trueenables dynamic sequence lengths — TRT builds a single engine that handles varying input shapes without recompiling.NCCL all-reduce ops from
RowwiseParallelare fused and converted to native TRTDistCollectivelayers (TRT 10.16+) or TRT-LLM plugin layers (TRT < 10.16).The warmup forward pass should use
torch._dynamo.mark_dynamic()to match the generate loop and avoid a recompile.Non-default TP subgroup — if the TRT engine should use a subgroup communicator (e.g. tensor-parallel inside a data-parallel job), pass the subgroup instead of
dist.group.WORLD:tp_group = dist.new_group(ranks=[0, 1]) with torch_tensorrt.distributed.distributed_context(tp_group, trt_model) as model: output = model(inp)
For a complete LLM example, see tensor_parallel_llama_llm.py and the multinode variant tensor_parallel_llama_multinode.py.
Workflow 2: torch.export (AOT) with Save / Load#
For deployment scenarios where you want to compile once and load many times (e.g. serving), use the export → compile → save → load workflow.
Limitation: torch.export does not currently support DTensor-parallelized
models (sharding propagation fails on symbolic reshapes). The workaround is to manually
slice weights per-rank and insert explicit _c10d_functional.all_reduce ops for
row-parallel layers.
Step 1: Export and Save#
import torch
import torch.distributed as dist
import torch_tensorrt
# Manual row-parallel wrapper (replaces DTensor RowwiseParallel)
class RowParallelLinear(torch.nn.Module):
def __init__(self, linear, group_name):
super().__init__()
self.linear = linear
self.group_name = group_name
def forward(self, x):
out = self.linear(x)
out = torch.ops._c10d_functional.all_reduce(out, "sum", self.group_name)
out = torch.ops._c10d_functional.wait_tensor(out)
return out
# Slice weights for this rank and wrap row-parallel layers
group_name = dist.distributed_c10d._get_default_group().group_name
# ... slice column-parallel weights on dim 0, row-parallel on dim 1 ...
model.o_proj = RowParallelLinear(model.o_proj, group_name)
# Export (no DTensor → export succeeds)
ep = torch.export.export(model, args=(input_ids,), strict=False)
# Compile with TRT
trt_model = torch_tensorrt.dynamo.compile(
ep,
inputs=[input_ids],
use_explicit_typing=True,
use_fp32_acc=True,
use_python_runtime=False,
min_block_size=1,
use_distributed_mode_trace=True,
)
# Save per-rank engine
torch_tensorrt.save(trt_model, f"/engines/model_rank{rank}.ep",
inputs=[input_ids], retrace=False)
Step 2: Load and Run#
Default world group (most common — all ranks share one TP group):
import os
import torch
import torch.distributed as dist
import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
from torch_tensorrt.distributed._nccl_utils import initialize_nccl_comm
dist.init_process_group(backend="nccl")
setup_nccl_for_torch_tensorrt()
initialize_nccl_comm() # eagerly create NCCL communicator for TRT
rank = dist.get_rank()
# Load the per-rank engine
loaded = torch_tensorrt.load(f"/engines/model_rank{rank}.ep")
trt_model = loaded.module()
with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as model:
output = model(input_ids)
dist.destroy_process_group()
os._exit(0)
Non-default TP subgroup (e.g. tensor-parallel inside a data-parallel job):
Use distributed_context(group, module) to pin the group on all TRT engines
in the loaded module and get the configured model back as the context value:
import os
import torch
import torch.distributed as dist
import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
from torch_tensorrt.distributed._nccl_utils import initialize_nccl_comm
dist.init_process_group(backend="nccl")
setup_nccl_for_torch_tensorrt()
initialize_nccl_comm()
tp_group = dist.new_group(ranks=[0, 1]) # tensor-parallel ranks
rank = dist.get_rank()
loaded = torch_tensorrt.load(f"/engines/model_rank{rank}.ep")
raw_model = loaded.module()
with torch_tensorrt.distributed.distributed_context(tp_group, raw_model) as trt_model:
output = trt_model(input_ids)
dist.destroy_process_group()
os._exit(0)
What happens under the hood:
_c10d_functional.all_reduce+wait_tensorare fused intotensorrt_fused_nccl_all_reduce_opby thefuse_distributed_opslowering pass.The fused op is converted to a native TRT
DistCollectivelayer inside the engine.At load time, the C++ runtime auto-resolves the NCCL process group name from the c10d registry.
initialize_nccl_comm()eagerly creates PyTorch’s lazy NCCL communicator is initialized before TRT tries to bind to it.The engine is serialized with
requires_native_multidevice=True, which tells the C++ runtime to bind the NCCL communicator on first execution.
For a complete example, see tensor_parallel_llama_export.py.
Multinode Inference#
For tensor parallelism across multiple nodes (one GPU per node), use
torchtrtrun — a torchrun-compatible launcher included in Torch-TensorRT
that automatically sets up NCCL before spawning worker processes.
# Node 0 (rank 0):
torchtrtrun --nproc_per_node=1 --nnodes=2 --node_rank=0 \
--rdzv_endpoint=<node0-ip>:29500 \
tensor_parallel_llama_multinode.py
# Node 1 (rank 1):
torchtrtrun --nproc_per_node=1 --nnodes=2 --node_rank=1 \
--rdzv_endpoint=<node0-ip>:29500 \
tensor_parallel_llama_multinode.py
Single-node multi-GPU also works:
torchtrtrun --nproc_per_node=2 tensor_parallel_llama_llm.py
Note
TRT 10.x NCCL library workaround — In TRT 10.x,
IExecutionContext::setCommunicator calls dlopen("libnccl.so")
at runtime via TRT’s internal libLoader. This happens inside the C++
runtime before any Python code executes, so setting LD_LIBRARY_PATH
or LD_PRELOAD inside the script is too late. This is fixed in TRT 11.0.
torchtrtrun works around this by:
Finding the
nvidia-ncclpip package (nvidia.nccl).Creating a
libnccl.so → libnccl.so.2symlink if missing (pip only shipslibnccl.so.2).Prepending the NCCL lib directory to
LD_LIBRARY_PATH.Setting
LD_PRELOADtolibnccl.so.2so TRT’sdlopenfinds the library already resident in the process.Spawning worker processes with
RANK,LOCAL_RANK,WORLD_SIZE,MASTER_ADDR, andMASTER_PORTset.
Manual setup (without torchtrtrun ):
If you prefer to launch processes yourself, replicate the NCCL setup manually:
# Step 1 — create the libnccl.so symlink
python -c "
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
setup_nccl_for_torch_tensorrt()
"
# Step 2 — find the NCCL lib directory
NCCL_LIB=$(python -c "
from torch_tensorrt.distributed._nccl_utils import get_nccl_library_path
print(get_nccl_library_path())
")
# Step 3 — launch with LD_PRELOAD and LD_LIBRARY_PATH set before the process starts
# Node 0:
LD_PRELOAD="$NCCL_LIB/libnccl.so.2" LD_LIBRARY_PATH="$NCCL_LIB:$LD_LIBRARY_PATH" \
RANK=0 WORLD_SIZE=2 MASTER_ADDR=<node0-ip> MASTER_PORT=29500 \
python tensor_parallel_llama_multinode.py
# Node 1:
LD_PRELOAD="$NCCL_LIB/libnccl.so.2" LD_LIBRARY_PATH="$NCCL_LIB:$LD_LIBRARY_PATH" \
RANK=1 WORLD_SIZE=2 MASTER_ADDR=<node0-ip> MASTER_PORT=29500 \
python tensor_parallel_llama_multinode.py
Important considerations for multinode:
Set a long NCCL timeout (e.g. 2 hours) via
dist.init_process_group(timeout=datetime.timedelta(hours=2))to prevent watchdog timeouts during TRT engine building.Add a
dist.barrier()after the warmup forward pass so all ranks finish engine building before starting inference.If using NGC or a system-installed NCCL (
libnccl.soalready onLD_LIBRARY_PATH), thetorchtrtrunsetup step is skipped automatically.
NCCL Collective Ops in TRT Engines#
Torch-TensorRT compiles NCCL collective ops (all-reduce, all-gather, reduce-scatter) directly into the TRT engine binary. There are two backend paths, selected automatically at import time based on what is available in your TRT build.
Selection priority (highest to lowest):
Native TRT DistCollective (TRT 10.16+ — preferred)
TRT-LLM plugin (TRT < 10.16 fallback)
PyTorch fallback (ops remain outside the TRT subgraph)
Check which path is active in your environment:
from torch_tensorrt._features import ENABLED_FEATURES
print(ENABLED_FEATURES.native_trt_collectives) # True → native path
print(ENABLED_FEATURES.trtllm_for_nccl) # True → TRT-LLM path
Path 1: Native TRT DistCollective (TRT 10.16+)#
The preferred path. No external libraries needed — NCCL collectives are first-class TRT layers compiled into the engine. Requires:
TensorRT ≥ 10.16
Torch-TensorRT built with
ENABLE_TRT_NCCL_COLLECTIVES=ON(default for CUDA builds)
How it works end-to-end:
Graph lowering — the
fuse_distributed_opspass rewrites each_c10d_functional.<collective> + wait_tensorpair into a single fused custom op:tensorrt_fused_nccl_all_reduce_optensorrt_fused_nccl_all_gather_optensorrt_fused_nccl_reduce_scatter_op
TRT compilation — the
_TRTInterpretersetsPreviewFeature.MULTIDEVICE_RUNTIME_10_16on the builder config and then each fused op converter callsINetworkDefinition.add_dist_collective()to insert a nativeDistCollectivelayer (trt.CollectiveOperation.ALL_REDUCE,ALL_GATHER, orREDUCE_SCATTER) into the TRT network.Serialization — the engine is serialized with the
requires_native_multidevice=Trueflag in the Torch-TRT metadata, signalling to the C++ runtime that NCCL communicator binding is required at load time.Runtime — NCCL communicator binding — on first execution, the C++ runtime calls
TRTEngine::bind_nccl_comm():It resolves the process group via the group name stored on the engine (or probes the c10d registry for the first group with an NCCL backend).
It fetches the
ncclComm_tpointer from PyTorch’sProcessGroupNCCL.It calls
IExecutionContext::setCommunicator()to pass the communicator to TRT.
This is why
initialize_nccl_comm()must be called before the first inference on a loaded engine: PyTorch creates the NCCL communicator lazily (on the first collective), so without the eager-init call,bind_nccl_comm()would find a null pointer.
Note
The communicator binding happens inside the C++ runtime after the process
launches. LD_LIBRARY_PATH and LD_PRELOAD must therefore be set before
the process starts (TRT’s internal libLoader calls dlopen("libnccl.so") at
startup). torchtrtrun handles this automatically.
Path 2: TRT-LLM Plugin (TRT < 10.16)#
Used automatically when native TRT collectives are not available. Requires:
The TRT-LLM plugin library (
libnvinfer_plugin_tensorrt_llm.so)Either
TRTLLM_PLUGINS_PATH=/path/to/libset in the environment, orUSE_TRTLLM_PLUGINS=1(downloads the TRT-LLM distribution automatically)
The same fuse_distributed_ops lowering pass runs, but the converter calls
TRT-LLM’s plugin API instead of add_dist_collective(). The runtime behaviour and
requires_native_multidevice flag are identical.
Path 3: PyTorch Fallback#
If neither backend is available, the _c10d_functional ops are not registered as TRT
converters and remain outside the TRT subgraph. They execute in PyTorch on every call.
This produces correct results but loses the performance benefit of in-engine collectives.
Warning
Compile with use_distributed_mode_trace=True regardless of which backend is
active. Without it, the FX tracer may not see the collective ops at all.
Confirming the active backend#
import torch_tensorrt
from torch_tensorrt._features import ENABLED_FEATURES
if ENABLED_FEATURES.native_trt_collectives:
print("Native TRT DistCollective (TRT 10.16+)")
elif ENABLED_FEATURES.trtllm_for_nccl:
print("TRT-LLM plugin")
else:
print("PyTorch fallback — collectives will NOT be inside the TRT engine")
Process Group Management and Teardown#
torch_tensorrt.distributed.distributed_context(group, module=None)#
The recommended way to manage the NCCL lifecycle for distributed TRT engines. This context manager:
On enter — sets the active process group and pins it on all TRT engines in module (both submodule and inlined engine storage patterns).
During the block — inference uses the specified NCCL communicator.
On exit — calls
release_nccl_comm()on every tracked multi-device engine, detaching the NCCL communicator from TRT’s execution context. This makesdist.destroy_process_group()safe to call afterward.
module may be a single compiled module, a list of modules, or omitted:
Single module — yields the module as the context value:
trt_model = torch.compile(model, backend="torch_tensorrt", ...)
_ = trt_model(inp) # warmup — triggers engine build
dist.barrier()
with torch_tensorrt.distributed.distributed_context(dist.group.WORLD, trt_model) as model:
output = model(inp)
dist.destroy_process_group()
os._exit(0)
Multiple modules — pass a list; yields a list in the same order. Useful when you have separately compiled programs (e.g. an encoder and a decoder) that both need NCCL teardown:
with torch_tensorrt.distributed.distributed_context(
tp_group, [encoder_trt, decoder_trt]
) as (enc, dec):
output = dec(enc(inp))
dist.destroy_process_group()
os._exit(0)
Non-default TP subgroup (e.g. tensor-parallel inside a data-parallel job):
tp_group = dist.new_group(ranks=[0, 1])
with torch_tensorrt.distributed.distributed_context(tp_group, trt_model) as model:
output = model(inp)
Without module — use at compile time when the model is created inside the block:
with torch_tensorrt.distributed.distributed_context(tp_group):
trt_model = torch.compile(model, backend="torch_tensorrt", ...)
output = trt_model(inp)
torch_tensorrt.distributed.set_distributed_mode(group, module)#
Permanently pins group on all TRT engines in module without entering a
context manager. Use this when the group should remain set across multiple
calls outside any with block:
torch_tensorrt.distributed.set_distributed_mode(tp_group, model)
output1 = model(inp1) # group already pinned
output2 = model(inp2)
Note
set_distributed_mode pins the group but does not handle teardown.
Prefer distributed_context(group, module) when you need safe cleanup.
Both APIs handle three engine storage patterns:
Submodule engines —
TorchTensorRTModulechildren produced bytorch.compileortorch_tensorrt.compile().Inlined engines —
torch.classes.tensorrt.Engineobjects stored as plain attributes on anfx.GraphModuleaftertorch_tensorrt.save()/torch_tensorrt.load().torch.compile engines — engines inside dynamo’s code cache, tracked via a thread-local registry populated at engine creation time.
For PythonTorchTensorRTModule (use_python_runtime=True), the group is
read lazily from the active context on the first forward call, so
distributed_context (without the module argument) is sufficient — keep the
context manager active for the duration of inference.
Note
Always end distributed scripts with os._exit(0) after
dist.destroy_process_group().
TRT engines and CUDA contexts register C++ destructors that can race with
Python’s garbage collector during normal interpreter shutdown, causing
segfaults or hangs. os._exit(0) bypasses the GC and exits immediately
with a clean exit code, avoiding this entirely.
dist.destroy_process_group()
os._exit(0) # bypass Python GC — TRT/CUDA destructors race during shutdown
This applies to all distributed workflows: torch.compile, export/load,
and multinode. It is safe because all meaningful work has completed before
this point.
Compilation Settings for Distributed Workloads#
Setting |
Default |
Description |
|---|---|---|
|
|
Use |
|
|
|
|
|
Respect dtypes set in model/inputs (recommended). Use |
|
|
Set to |
|
|
Use FP32 accumulation for FP16 models. Improves numerical accuracy. |
Launching Distributed Scripts#
Single node, multiple GPUs — use torchrun or mpirun:
# torchrun
torchrun --nproc_per_node=2 tensor_parallel_llama_llm.py
# mpirun
mpirun -n 2 python tensor_parallel_llama_llm.py
Multiple nodes — use environment variables:
# Node 0
RANK=0 WORLD_SIZE=2 MASTER_ADDR=<ip> MASTER_PORT=29500 \
python tensor_parallel_llama_multinode.py
# Node 1
RANK=1 WORLD_SIZE=2 MASTER_ADDR=<ip> MASTER_PORT=29500 \
python tensor_parallel_llama_multinode.py
Examples#
The following complete examples are available in the Torch-TensorRT repository:
Script |
Description |
|---|---|
|
Data-parallel GPT-2 inference with Accelerate |
|
Data-parallel Stable Diffusion with Accelerate |
|
Two-layer MLP with column / row parallel sharding (torch.compile) |
|
Tensor-parallel attention with rotary positional embeddings (RoPE) |
|
Multinode NCCL test (C++ and Python runtime) |
|
Export → save → load round-trip test for distributed engines |
|
Llama TP with torch.compile (multinode, C++ runtime) |
|
Llama TP: export + save + load workflow (multinode) |
|
Qwen TP with torch.compile (multinode) |