Torch-TensorRT Distributed Inference

This interactive script is a sample of data-parallel distributed inference on a Stable Diffusion model, using the Hugging Face Accelerate library together with the Torch-TensorRT workflow.

Imports and Model Definition

[ ]:
import torch
import torch_tensorrt
from accelerate import PartialState
from diffusers import DiffusionPipeline

model_id = "CompVis/stable-diffusion-v1-4"

# Instantiate the Stable Diffusion pipeline with FP16 weights.
# `variant="fp16"` selects the half-precision weight files; selecting them via
# `revision="fp16"` is deprecated in diffusers.
pipe = DiffusionPipeline.from_pretrained(
    model_id, variant="fp16", torch_dtype=torch.float16
)

# PartialState exposes this process's place in the distributed launch; move the
# whole pipeline onto the device assigned to this process.
distributed_state = PartialState()
pipe = pipe.to(distributed_state.device)

backend = "torch_tensorrt"

# Optimize the UNet portion with Torch-TensorRT. torch.compile is lazy: the
# TensorRT engine is built on the first forward pass through the UNet.
pipe.unet = torch.compile(
    pipe.unet,
    backend=backend,
    options={
        # Truncate int64/float64 values to int32/float32 for TensorRT.
        "truncate_long_and_double": True,
        # Build the engine in FP16 to match the FP16 weights loaded above.
        "precision": torch.float16,
        # Run the engine through the Python runtime instead of the C++ one.
        "use_python_runtime": True,
    },
    # Compile for static input shapes.
    dynamic=False,
)

# Each process drives its own device; multi-device safe mode makes the
# Torch-TensorRT runtime check device consistency before engine execution.
torch_tensorrt.runtime.set_multi_device_safe_mode(True)

Inference

[ ]:
# Assume there are 2 processes (2 devices), so each process receives exactly
# one of the prompts below.
prompts = ["a dog", "a cat"]
with distributed_state.split_between_processes(prompts) as prompt:
    # `prompt` is this process's share of `prompts`; the first generated image
    # is saved under this process's index.
    # NOTE(review): with fewer processes than prompts, the remaining images in
    # a process's share are generated but never saved.
    images = pipe(prompt).images
    result = images[0]
    result.save(f"result_{distributed_state.process_index}.png")