Tensor Parallel Distributed Inference with Torch-TensorRT

The example below shows how to use the Torch-TensorRT backend for distributed inference with tensor parallelism.

This example demonstrates:

- Setting up a distributed environment for tensor parallelism
- Sharding the model across multiple GPUs
- Compiling the sharded model with Torch-TensorRT
- Running distributed inference

Usage

mpirun -n 2 --allow-run-as-root python tensor_parallel_simple_example.py
[ ]:
import time

import tensorrt as trt
import torch
import torch.distributed as dist
import torch.nn as nn
from tensor_parallel_initialize_dist import (
    cleanup_distributed_env,
    get_tensor_parallel_device_mesh,
    initialize_distributed_env,
)

# Initialize distributed environment and logger BEFORE importing torch_tensorrt
# This ensures logging is configured before any import-time log messages.
# The helper's result is unpacked as (device mesh, world size, rank, logger);
# the mesh is later handed to parallelize_module and the logger is used for
# all status messages in this script.
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
    "tensor_parallel_simple_example"
)

# Importing torch_tensorrt makes the "torch_tensorrt" backend available to
# torch.compile (used further down). It must stay *after* the setup above so
# that its import-time log messages go through the configured logger.
import torch_tensorrt  # noqa: E402, F401

from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""


class ToyModel(nn.Module):
    """Small MLP used to demonstrate tensor-parallel sharding.

    Layer sizes: 10 -> 3200 -> 1600 -> 500 -> 100. The two
    (``in_proj``/``out_proj``) and (``in_proj2``/``out_proj2``) pairs are the
    units that the parallelization plan shards column-wise / row-wise.
    """

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 3200)
        # A single stateless ReLU module is reused for every activation.
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(3200, 1600)
        self.in_proj2 = nn.Linear(1600, 500)
        self.out_proj2 = nn.Linear(500, 100)

    def forward(self, x):
        """Apply the MLP to ``x``, whose last dimension must be 10.

        ReLU follows ``in_proj``, ``out_proj`` and ``in_proj2``; the final
        ``out_proj2`` output is returned without an activation.
        """
        x = self.out_proj(self.relu(self.in_proj(x)))
        x = self.relu(x)
        x = self.out_proj2(self.relu(self.in_proj2(x)))
        return x


logger.info(f"Starting PyTorch TP example on rank {_rank}.")

try:
    # Validate the launch configuration. An explicit raise (instead of
    # `assert`) keeps this check active when Python runs with -O.
    if _world_size % 2 != 0:
        raise ValueError(
            f"TP examples require even number of GPUs, but got {_world_size} gpus"
        )

    # Create the model and move it to the GPU. The device mesh set up by
    # initialize_distributed_env has already mapped GPU ids, so "cuda" refers
    # to this rank's device.
    tp_model = ToyModel().to("cuda")

    # Custom parallelization plan for the model: each Colwise/Rowwise pair
    # splits one two-layer MLP across the mesh, and Shard(0) declares the
    # module inputs/outputs as sharded along dim 0.
    tp_model = parallelize_module(
        module=tp_model,
        device_mesh=device_mesh,
        parallelize_plan={
            "in_proj": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj": RowwiseParallel(output_layouts=Shard(0)),
            "in_proj2": ColwiseParallel(input_layouts=Shard(0)),
            "out_proj2": RowwiseParallel(output_layouts=Shard(0)),
        },
    )

    # Eager (uncompiled) reference result for the correctness check below.
    torch.manual_seed(0)
    inp = torch.rand(20, 10, device="cuda")
    python_result = tp_model(inp)

    backend = "torch_tensorrt"
    tp_model = torch.compile(
        tp_model,
        backend=backend,
        options={
            "truncate_long_and_double": True,
            "enabled_precisions": {torch.float32, torch.float16},
            "use_python_runtime": True,
            "min_block_size": 1,
            "use_distributed_mode_trace": True,
        },
        dynamic=None,
    )

    # For TP, input needs to be same across all TP ranks.
    # Setting the random seed is to mimic the behavior of dataloader.
    torch.manual_seed(0)
    inp = torch.rand(20, 10, device="cuda")

    # torch.compile is lazy, so the first call includes compilation plus one
    # inference. CUDA work is asynchronous: synchronize before starting the
    # clock (so pending eager work is not counted) and after the call (so the
    # whole run is counted), using a monotonic clock.
    torch.cuda.synchronize()
    start = time.perf_counter()
    output = tp_model(inp)
    torch.cuda.synchronize()
    end = time.perf_counter()
    logger.info(f"Compilation time is {end - start}")
    assert (python_result - output).std() < 0.01, "Result is not correct."
finally:
    # This cleans up the distributed process group, also when a step above
    # raised, so a failing rank does not leave the group behind.
    cleanup_distributed_env()