Two-node test: torch.export → TRT AOT compile → save → load → inference.
Tests the full serialization round-trip for tensor-parallel models with native TRT NCCL collectives (C++ runtime).
Reads RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT from the environment.
Usage (set RANK0_IP to the address of the rank-0 node on both nodes)
NCCL_LIB=$(python -c "from torch_tensorrt.distributed._nccl_utils import get_nccl_library_path; print(get_nccl_library_path())")
Rank 0:
LD_LIBRARY_PATH="$NCCL_LIB:$LD_LIBRARY_PATH" \
RANK=0 WORLD_SIZE=2 MASTER_ADDR=${RANK0_IP} MASTER_PORT=29500 \
uv run python examples/distributed_inference/test_multinode_export_save_load.py
Rank 1:
LD_LIBRARY_PATH="$NCCL_LIB:$LD_LIBRARY_PATH" \
RANK=1 WORLD_SIZE=2 MASTER_ADDR=${RANK0_IP} MASTER_PORT=29500 \
uv run python examples/distributed_inference/test_multinode_export_save_load.py
[ ]:
import datetime
import faulthandler
import os
import sys
import tempfile
from pathlib import Path
faulthandler.enable()
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree
import torch_tensorrt
from torch.distributed._tensor import Shard
from torch.distributed._tensor.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
# NOTE(review): presumably registered so torch.export treats DTensorSpec
# values as opaque constants instead of trying to flatten them — confirm.
torch.utils._pytree.register_constant(
    torch.distributed.tensor._dtensor_spec.DTensorSpec
)
# Project helper that prepares NCCL for Torch-TensorRT; it runs before the
# process group is created.
setup_nccl_for_torch_tensorrt()

# Rank and world size are supplied by the launcher's environment.
rank, world_size = (int(os.environ[var]) for var in ("RANK", "WORLD_SIZE"))

# Each node contributes a single GPU, so every process binds to device 0.
torch.cuda.set_device(0)
dist.init_process_group(
    backend="nccl",
    world_size=world_size,
    rank=rank,
    timeout=datetime.timedelta(hours=2),
)

torch.manual_seed(0)
device_mesh = init_device_mesh("cuda", (world_size,))
print(f"[Rank {rank}/{world_size}] distributed init OK", flush=True)
# ---------------------------------------------------------------------------
# Model: simple TP MLP (same as test_multinode_nccl.py)
# ---------------------------------------------------------------------------
class ToyModel(nn.Module):
    """Small MLP used as the tensor-parallel test model.

    Feature sizes are 10 -> 3200 -> 1600 -> 500 -> 100, with a ReLU after
    every projection except the last. The ``in_proj*`` layers are the ones
    sharded column-wise and the ``out_proj*`` layers row-wise elsewhere in
    this script.
    """

    def __init__(self):
        super().__init__()
        # Registration order decides how a seeded RNG fills the weights, so it
        # must stay fixed for seeded builds to be reproducible across ranks.
        self.in_proj = nn.Linear(10, 3200)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        hidden = self.relu(self.in_proj(x))
        hidden = self.relu(self.out_proj(hidden))
        hidden = self.relu(self.in_proj2(hidden))
        return self.out_proj2(hidden)
# ---------------------------------------------------------------------------
# Exportable model: manual sharding + explicit all-reduce
# ---------------------------------------------------------------------------
class _RowParallelLinear(nn.Module):
    """Local linear shard followed by an explicit NCCL all-reduce.

    Stands in for DTensor's ``RowwiseParallel`` so the model can go through
    ``torch.export``: the wrapped ``linear`` produces this rank's partial
    result, which is then summed across the process group named
    ``group_name`` with the functional collective ops.
    """

    def __init__(self, linear, group_name):
        super().__init__()
        self.linear = linear
        self.group_name = group_name

    def forward(self, x):
        partial = self.linear(x)
        reduced = torch.ops._c10d_functional.all_reduce(partial, "sum", self.group_name)
        # Functional collectives are asynchronous; wait for the reduction
        # before handing the tensor on.
        return torch.ops._c10d_functional.wait_tensor(reduced)
def _shard_for_rank(tensor, dim, rank, world_size):
    """Return ``rank``'s equal, contiguous slice of ``tensor`` along ``dim``.

    Raises:
        ValueError: If ``tensor.shape[dim]`` is not divisible by ``world_size``
            (silently truncating would make the sharded math wrong).
    """
    size = tensor.shape[dim]
    if size % world_size:
        raise ValueError(
            f"cannot shard dimension {dim} of size {size} evenly across "
            f"{world_size} ranks"
        )
    chunk = size // world_size
    return tensor.detach().narrow(dim, rank * chunk, chunk).contiguous()


def build_exportable_model(rank, world_size):
    """Build a ToyModel with manually sliced weights and explicit all-reduces.

    Column-parallel layers (``in_proj``, ``in_proj2``) keep this rank's slice
    of output features (weight dim 0 and the bias). Row-parallel layers
    (``out_proj``, ``out_proj2``) keep this rank's slice of input features
    (weight dim 1) and are wrapped in ``_RowParallelLinear`` so the partial
    outputs are summed across ranks.

    Because those partial outputs are *summed*, a row-parallel bias must be
    contributed by exactly one rank: it is kept on rank 0 and zeroed on the
    others. Keeping it everywhere would add it ``world_size`` times.

    Args:
        rank: This process's rank in the default process group.
        world_size: Number of tensor-parallel ranks.

    Returns:
        The sharded model, on CUDA.

    Raises:
        ValueError: If a sharded dimension is not divisible by ``world_size``.
    """
    group_name = dist.distributed_c10d._get_default_group().group_name
    model = ToyModel().to("cuda")

    # Column-parallel: slice the output dim (dim 0) of weight and bias.
    for proj in (model.in_proj, model.in_proj2):
        proj.weight = nn.Parameter(_shard_for_rank(proj.weight, 0, rank, world_size))
        if proj.bias is not None:
            proj.bias = nn.Parameter(_shard_for_rank(proj.bias, 0, rank, world_size))

    # Row-parallel: slice the input dim (dim 1), then wrap with all-reduce.
    for attr in ("out_proj", "out_proj2"):
        proj = getattr(model, attr)
        proj.weight = nn.Parameter(_shard_for_rank(proj.weight, 1, rank, world_size))
        if proj.bias is not None and rank != 0:
            # Zeros (rather than dropping the bias) keep every rank's graph
            # structurally identical while adding the bias exactly once.
            proj.bias = nn.Parameter(torch.zeros_like(proj.bias))
        setattr(model, attr, _RowParallelLinear(proj, group_name))
    return model
def rank_path(save_dir, rank, world_size):
    """Return, as a string, the path of this rank's engine file in ``save_dir``."""
    filename = f"tp_rank{rank}_of_{world_size}.pt2"
    return str(Path(save_dir).joinpath(filename))
# ---------------------------------------------------------------------------
# Build DTensor baseline for comparison
# ---------------------------------------------------------------------------
# DTensor reference: the same ToyModel parallelised with PyTorch's
# ColwiseParallel / RowwiseParallel styles over the 1-D device mesh, with the
# input declared Shard(0) and outputs left as Shard(0).
tp_model = ToyModel().to("cuda")
tp_model = parallelize_module(
    tp_model,
    device_mesh,
    {
        "in_proj": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj": RowwiseParallel(output_layouts=Shard(0)),
        "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
        "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
    },
)

# Shared example input, also used by the tests below for tracing and inference.
inp = torch.rand(20, 10, device="cuda")
# Inference only: without no_grad the forward pass would record an autograd
# graph that pins its intermediate activations for the rest of the script.
with torch.no_grad():
    python_result = tp_model(inp)
# NOTE(review): python_result is only printed here; it is not compared
# against the TensorRT outputs.
print(f"[Rank {rank}] PyTorch TP baseline OK, shape={python_result.shape}", flush=True)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
# Names of the tests that passed / failed on this rank; FAILED drives the
# process exit code.
PASSED, FAILED = [], []
# Per-process scratch directory for the serialized per-rank engines.
save_dir = tempfile.mkdtemp(prefix="trt_export_test_")
def test_export_compile_save():
    """Test 1: torch.export -> TRT AOT compile -> save per-rank.

    Builds the manually sharded model, compiles it with Torch-TensorRT, checks
    the engine output against the same model's eager output, then saves this
    rank's engine to ``rank_path(save_dir, rank, world_size)``.

    Returns:
        The compiled TensorRT module, reused by the compiled-vs-loaded check.

    Raises:
        AssertionError: If the outputs disagree or the saved file is missing.
    """
    torch.manual_seed(0)
    model = build_exportable_model(rank, world_size)

    # Reference: the manually-sharded model's own eager PyTorch output.
    with torch.no_grad():
        ref_output = model(inp)

    # Export with static shapes (dynamic shapes + DTensor not supported)
    ep = torch.export.export(model, args=(inp,), strict=False)
    trt_model = torch_tensorrt.dynamo.compile(
        ep,
        inputs=[inp],
        use_fp32_acc=True,
        # Target whichever GPU torch.cuda.set_device() selected for this process.
        device=torch.device("cuda", torch.cuda.current_device()),
        disable_tf32=True,
        use_python_runtime=False,
        min_block_size=1,
        use_distributed_mode_trace=True,
        assume_dynamic_shape_support=True,
    )

    # Verify the TRT output matches the same model's PyTorch output. The std
    # of the difference alone cannot see a constant offset, so the largest
    # element-wise error is bounded as well.
    output = trt_model(inp)
    diff = ref_output - output
    std = float(diff.std())
    max_diff = float(diff.abs().max())
    assert std < 0.01 and max_diff < 0.01, (
        f"Export compile output mismatch: std={std}, max_diff={max_diff}"
    )
    print(
        f"[Rank {rank}] Export compile OK (std={std:.6f}, max_diff={max_diff:.6f})",
        flush=True,
    )

    # Save per-rank engine; wait for every rank before checking the file.
    path = rank_path(save_dir, rank, world_size)
    torch_tensorrt.save(trt_model, path, inputs=[inp], retrace=False)
    dist.barrier()
    assert os.path.isfile(path), f"Engine file not found: {path}"
    size_mb = os.path.getsize(path) / 1e6
    print(f"[Rank {rank}] Saved engine to {path} ({size_mb:.1f} MB)", flush=True)
    return trt_model
def test_load_and_infer():
    """Test 2: load this rank's saved engine and run inference.

    Inference uses only the loaded engine (no model weights are needed). The
    reference output comes from a freshly built, identically seeded sharded
    PyTorch model.

    Raises:
        AssertionError: If the loaded engine's output disagrees with the
            reference.
    """
    # Eagerly initialize PyTorch's NCCL communicator so TRT's
    # bind_nccl_comm() can extract the ncclComm_t on first engine execution.
    from torch_tensorrt.distributed._nccl_utils import initialize_nccl_comm

    initialize_nccl_comm()

    path = rank_path(save_dir, rank, world_size)
    trt_model = torch_tensorrt.load(path).module()
    output = trt_model(inp)

    # Compare against a manually-sharded PyTorch model with the same seed.
    torch.manual_seed(0)
    ref_model = build_exportable_model(rank, world_size)
    with torch.no_grad():
        ref_output = ref_model(inp)

    # Bound both the spread and the worst element of the error (std alone
    # misses a constant offset).
    diff = ref_output - output
    std = float(diff.std())
    max_diff = float(diff.abs().max())
    assert std < 0.01 and max_diff < 0.01, (
        f"Loaded engine output mismatch: std={std}, max_diff={max_diff}"
    )
    print(
        f"[Rank {rank}] Load + infer OK (std={std:.6f}, max_diff={max_diff:.6f})",
        flush=True,
    )
def test_loaded_matches_compiled(compiled_output):
    """Test 3: the engine reloaded from disk reproduces the in-memory one.

    Args:
        compiled_output: The compiled TensorRT module returned by
            ``test_export_compile_save`` (a callable module, despite the name).

    Raises:
        AssertionError: If any element differs by 1e-3 or more.
    """
    loaded_module = torch_tensorrt.load(rank_path(save_dir, rank, world_size)).module()
    from_disk = loaded_module(inp)
    from_memory = compiled_output(inp)
    max_diff = float((from_memory - from_disk).abs().max())
    assert max_diff < 1e-3, f"Compiled vs loaded mismatch: max_diff={max_diff}"
    print(f"[Rank {rank}] Compiled vs loaded match (max_diff={max_diff:.6f})", flush=True)
import gc
import shutil
import traceback

compiled_model = None


def _report_failure(name, exc):
    """Print a failed test's error and traceback, and record it in FAILED.

    Call only from inside an ``except`` block so ``traceback.print_exc``
    sees the active exception.
    """
    print(f"[Rank {rank}] FAIL {name}: {exc}", flush=True)
    traceback.print_exc()
    FAILED.append(name)


# Run tests 1 and 2; all ranks synchronize before each one.
for name, fn in (
    ("export_compile_save", test_export_compile_save),
    ("load_and_infer", test_load_and_infer),
):
    dist.barrier()
    # Reset every iteration: otherwise, when test 2 raises, `result` would
    # still hold test 1's TRT engine and keep it alive past
    # destroy_process_group() below.
    result = None
    try:
        result = fn()
    except Exception as e:
        _report_failure(name, e)
        continue
    PASSED.append(name)
    if name == "export_compile_save":
        compiled_model = result

# Test 3 only if test 1 passed
if compiled_model is not None:
    dist.barrier()
    try:
        test_loaded_matches_compiled(compiled_model)
    except Exception as e:
        _report_failure("loaded_matches_compiled", e)
    else:
        PASSED.append("loaded_matches_compiled")

# Delete TRT engines before destroying the process group — the engines hold
# a reference to the NCCL communicator and will segfault if NCCL is torn
# down first. gc.collect() also frees engines kept alive only by a cycle.
del compiled_model
gc.collect()
torch.cuda.empty_cache()
dist.destroy_process_group()

# Keep this process's saved engines for inspection when something failed;
# otherwise remove the scratch directory created with tempfile.mkdtemp().
if FAILED:
    print(f"[Rank {rank}] Keeping saved engines in {save_dir}", flush=True)
else:
    shutil.rmtree(save_dir, ignore_errors=True)

print(f"[Rank {rank}] Results — passed: {PASSED} failed: {FAILED}", flush=True)
# os._exit() ends the process without normal interpreter cleanup (no atexit
# handlers, no stdio flush — hence flush=True above).
os._exit(0 if not FAILED else 1)