Tensor Parallel Distributed Inference with Torch-TensorRT#
Below example shows how to use Torch-TensorRT backend for distributed inference with tensor parallelism.
This example demonstrates: - Setting up distributed environment for tensor parallelism - Model sharding across multiple GPUs - Compilation with Torch-TensorRT - Distributed inference execution
Usage#
# JIT mode python runtime
mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_cpp
# JIT mode cpp runtime
mpirun -n 2 python tensor_parallel_simple_example.py --mode jit_python
WIP: Export and load mode
mpirun -n 2 python tensor_parallel_simple_example.py --mode export --save-path /tmp/tp_model.ep
mpirun -n 2 python tensor_parallel_simple_example.py --mode load --save-path /tmp/tp_model.ep
[ ]:
import argparse
import time
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.utils._pytree
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)
torch.utils._pytree.register_constant(
torch.distributed.tensor._dtensor_spec.DTensorSpec
)
parser = argparse.ArgumentParser(description="Tensor Parallel Simple Example")
parser.add_argument(
"--mode",
type=str,
choices=["jit_python", "jit_cpp", "export", "load"],
default="jit_python",
)
parser.add_argument("--save-path", type=str, default="/tmp/tp_model.ep")
args = parser.parse_args()
device_mesh, _world_size, _rank, logger = initialize_distributed_env(
"tensor_parallel_simple_example"
)
import torch_tensorrt
from torch_tensorrt.distributed import setup_nccl_for_torch_tensorrt
setup_nccl_for_torch_tensorrt()
from torch.distributed._tensor import Shard
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
"""
This example takes some code from https://github.com/pytorch/examples/blob/main/distributed/tensor_parallelism/tensor_parallel_example.py
"""
class ToyModel(nn.Module):
"""MLP based model"""
def __init__(self):
super(ToyModel, self).__init__()
self.in_proj = nn.Linear(10, 3200)
self.relu = nn.ReLU()
self.out_proj = nn.Linear(3200, 1600)
self.in_proj2 = nn.Linear(1600, 500)
self.out_proj2 = nn.Linear(500, 100)
def forward(self, x):
x = self.out_proj(self.relu(self.in_proj(x)))
x = self.relu(x)
x = self.out_proj2(self.relu(self.in_proj2(x)))
return x
logger.info(f"Starting PyTorch TP example on rank {_rank}.")
assert (
_world_size % 2 == 0
), f"TP examples require even number of GPUs, but got {_world_size} gpus"
# # create model and move it to GPU - init"cuda"_mesh has already mapped GPU ids.
tp_model = ToyModel().to("cuda")
# Custom parallelization plan for the model
tp_model = parallelize_module(
module=tp_model,
device_mesh=device_mesh,
parallelize_plan={
"in_proj": ColwiseParallel(input_layouts=Shard(0)),
"out_proj": RowwiseParallel(output_layouts=Shard(0)),
"in_proj2": ColwiseParallel(input_layouts=Shard(0)),
"out_proj2": RowwiseParallel(output_layouts=Shard(0)),
},
)
torch.manual_seed(0)
inp = torch.rand(20, 10, device="cuda")
python_result = tp_model(inp)
if args.mode == "load":
# Load per-rank model: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep
logger.info(f"Loading from {args.save_path}")
loaded_program = torch_tensorrt.load(args.save_path)
output = loaded_program.module()(inp)
dist.barrier()
assert (python_result - output).std() < 0.01, "Result mismatch"
logger.info("Load successful!")
elif args.mode == "jit_python":
trt_model = torch.compile(
tp_model,
backend="torch_tensorrt",
options={
"truncate_long_and_double": True,
"use_python_runtime": True,
"min_block_size": 1,
},
)
output = trt_model(inp)
dist.barrier()
assert (python_result - output).std() < 0.01, "Result mismatch"
logger.info("JIT compile successful!")
elif args.mode == "jit_cpp":
trt_model = torch.compile(
tp_model,
backend="torch_tensorrt",
options={
"truncate_long_and_double": True,
"use_python_runtime": False,
"min_block_size": 1,
},
)
output = trt_model(inp)
dist.barrier()
assert (python_result - output).std() < 0.01, "Result mismatch"
logger.info("JIT compile successful!")
elif args.mode == "export":
# Export: torch.export + dynamo.compile - AOT compilation, can save
exported_program = torch.export.export(tp_model, (inp,), strict=False)
trt_model = torch_tensorrt.dynamo.compile(
exported_program,
inputs=[inp],
truncate_double=True,
use_python_runtime=False,
min_block_size=1,
use_distributed_mode_trace=True,
)
output = trt_model(inp)
dist.barrier()
assert (python_result - output).std() < 0.01, "Result mismatch"
# Save per-rank: /tmp/tp_model.ep -> /tmp/tp_model_rank0_of_2.ep
save_path = torch_tensorrt.save(trt_model, args.save_path, inputs=[inp])
logger.info(f"Saved to {save_path}")
dist.barrier()
cleanup_distributed_env()
logger.info("Done!")