Tensor Parallel Distributed Inference with Torch-TensorRT (torchrun)#

Same model as tensor_parallel_simple_example.py but launched with torchrun / python -m torch_tensorrt.distributed.run instead of mpirun.

Usage#

# Single-node, 2 GPUs
torchrun --nproc_per_node=2 tensor_parallel_simple_example_torchrun.py

# Two nodes, 1 GPU each — run on BOTH nodes simultaneously:
#   Node 0 (spirit):
RANK=0 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \
    uv run python tensor_parallel_simple_example_torchrun.py

#   Node 1 (opportunity):
RANK=1 WORLD_SIZE=2 MASTER_ADDR=<spirit_ip> MASTER_PORT=29500 LOCAL_RANK=0 \
    uv run python tensor_parallel_simple_example_torchrun.py

# Or via torchtrtrun (sets up NCCL library paths automatically):
python -m torch_tensorrt.distributed.run --nproc_per_node=2 \
    tensor_parallel_simple_example_torchrun.py

Optional args: –mode jit_python | jit_cpp | export | load (default: jit_python) –save-path /tmp/tp_model.ep –precision FP16 | BF16 | FP32 (default: FP16) –debug

[ ]:
import argparse
import datetime
import logging
import os
from contextlib import nullcontext

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree
from torch.distributed.device_mesh import init_device_mesh
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt

torch.utils._pytree.register_constant(
    torch.distributed.tensor._dtensor_spec.DTensorSpec
)

# One GPU per node; LOCAL_RANK defaults to 0 for plain env-var launch.
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)
DEVICE = torch.device(f"cuda:{local_rank}")

# 2-hour timeout so TRT engine building doesn't trigger the NCCL watchdog.
dist.init_process_group(backend="nccl", timeout=datetime.timedelta(hours=2))
rank = dist.get_rank()
world_size = dist.get_world_size()

import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt

setup_nccl_for_torch_tensorrt()

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

logging.basicConfig(
    level=logging.INFO,
    format=f"[Rank {rank}] %(levelname)s: %(message)s",
)
logger = logging.getLogger(__name__)
logger.info(f"dist init OK  rank={rank}/{world_size}  device={DEVICE}")


class ToyModel(nn.Module):
    """MLP based model"""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 3200)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        x = self.out_proj(self.relu(self.in_proj(x)))
        x = self.relu(x)
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


def get_model(device_mesh):
    assert (
        world_size % 2 == 0
    ), f"TP examples require an even number of GPUs, got {world_size}"
    model = ToyModel().to(DEVICE)
    parallelize_module(
        module=model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
            "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
        },
    )
    logger.info("Model built and sharded across ranks.")
    return model


def compile_torchtrt(model, args):
    model.eval()

    use_fp32_acc = args.precision == "FP16"

    use_python_runtime = args.mode == "jit_python"

    with torch_tensorrt.logging.debug() if args.debug else nullcontext():
        trt_model = torch.compile(
            model,
            backend="torch_tensorrt",
            dynamic=False,
            options={
                "use_fp32_acc": use_fp32_acc,
                "device": DEVICE,
                "disable_tf32": True,
                "use_python_runtime": use_python_runtime,
                "debug": args.debug,
                "min_block_size": 1,
            },
        )
    return trt_model


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Tensor Parallel Simple Example (torchrun)"
    )
    parser.add_argument(
        "--mode",
        type=str,
        choices=["jit_python", "jit_cpp", "export", "load"],
        default="jit_python",
    )
    parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep")
    parser.add_argument(
        "--precision",
        default="FP16",
        choices=["FP16", "BF16", "FP32"],
    )
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()

    device_mesh = init_device_mesh("cuda", (world_size,))

    with torch.inference_mode():
        model = get_model(device_mesh)

        torch.manual_seed(0)
        inp = torch.rand(20, 10, device=DEVICE)
        python_result = model(inp)

        if args.mode == "load":
            logger.info(f"Loading from {args.save_path}")
            loaded_program = torch_tensorrt.load(args.save_path)
            output = loaded_program.module()(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("Load successful!")

        elif args.mode in ("jit_python", "jit_cpp"):
            trt_model = compile_torchtrt(model, args)

            # Warmup: trigger engine build on all ranks, then barrier so no
            # rank races ahead to the next NCCL collective before others finish.
            logger.info("Warming up (triggering TRT engine build)...")
            _ = trt_model(inp)
            dist.barrier()
            logger.info("All ranks compiled. Running inference...")

            with torch_tensorrt.distributed.distributed_context(
                dist.group.WORLD, trt_model
            ) as dist_model:
                output = dist_model(inp)

            assert (python_result - output).std() < 0.01, "Result mismatch"
            logger.info("JIT compile successful!")

        elif args.mode == "export":
            with torch.inference_mode():
                exported_program = torch.export.export(model, (inp,), strict=False)
            trt_model = torch_tensorrt.dynamo.compile(
                exported_program,
                inputs=[inp],
                use_fp32_acc=True,
                device=DEVICE,
                disable_tf32=True,
                use_python_runtime=False,
                min_block_size=1,
                use_distributed_mode_trace=True,
                assume_dynamic_shape_support=True,
            )
            with torch.inference_mode():
                output = trt_model(inp)
            assert (python_result - output).std() < 0.01, "Result mismatch"
            save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp])
            logger.info(f"Saved to {save_path}")
            dist.barrier()

    dist.destroy_process_group()
    logger.info("Done!")
    # Bypass Python GC — TRT/CUDA destructors can segfault during
    # interpreter shutdown due to unpredictable destruction order.
    os._exit(0)