MutableTorchTensorRTModule#
MutableTorchTensorRTModule is a drop-in wrapper for any torch.nn.Module that
compiles to TRT on the first forward call and automatically refits the TRT engine
whenever weights change — without recompilation. It is the recommended path for
integrating Torch-TensorRT into diffusion pipelines (Stable Diffusion, FLUX, etc.) and
any workflow that swaps LoRA adapters or checkpoint weights between runs.
See the complete worked example: mutable_torchtrt_module_example
Diffusers / LoRA Use Case#
The primary use case is replacing a sub-module inside a HuggingFace
DiffusionPipeline. This is a one-line change — the rest of your pipeline code
stays identical:
import torch
import torch_tensorrt
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
# The only extra line you need
pipe.unet = torch_tensorrt.MutableTorchTensorRTModule(
    pipe.unet,
    use_explicit_typing=True,  # pipeline already loaded in float16 via torch_dtype
    use_python_runtime=True,
)
The pipeline’s unet is now backed by a TRT engine. The first call to pipe(...)
triggers compilation; subsequent calls run the cached engine.
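The compile-on-first-call behaviour can be pictured with a toy wrapper. This is only an illustration of the pattern, not the library's implementation: the class `LazyCompiled`, its `_compile` method, and the `compile_count` counter are hypothetical names, and the "engine" is just the wrapped Python callable.

```python
from typing import Any, Callable, Optional


class LazyCompiled:
    """Toy wrapper: build an 'engine' on the first call, reuse it afterwards."""

    def __init__(self, fn: Callable[..., Any]) -> None:
        self._fn = fn
        self._engine: Optional[Callable[..., Any]] = None
        self.compile_count = 0

    def _compile(self) -> Callable[..., Any]:
        # Stand-in for the expensive engine build step.
        self.compile_count += 1
        return self._fn

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        if self._engine is None:  # only true on the first call
            self._engine = self._compile()
        return self._engine(*args, **kwargs)


double = LazyCompiled(lambda x: 2 * x)
results = [double(3), double(4), double(5)]
```

Three calls produce three results but only one build, which is why the first `pipe(...)` call is noticeably slower than the ones that follow.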
Dynamic shapes are required for diffusion pipelines because batch size and image dimensions vary. Set the ranges before the first call:
BATCH = torch.export.Dim("BATCH", min=2, max=24)
_H = torch.export.Dim("_H", min=16, max=32)
_W = torch.export.Dim("_W", min=16, max=32)
pipe.unet.set_expected_dynamic_shape_range(
    args_dynamic_shape=({0: BATCH, 2: 4 * _H, 3: 4 * _W}, {}),
    kwargs_dynamic_shape={
        "encoder_hidden_states": {0: BATCH},
        "added_cond_kwargs": {
            "text_embeds": {0: BATCH},
            "time_ids": {0: BATCH},
        },
        "return_dict": None,  # None-valued kwargs are excluded
    },
)
image = pipe("cinematic photo, 4k", num_inference_steps=30).images[0]
image.save("without_lora.jpg")
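The derived dimensions `4 * _H` and `4 * _W` restrict the latent height and width to multiples of 4. The arithmetic below spells out the sizes this range covers. It is a plain-Python sketch: `derived_range` is a hypothetical helper, and the factor of 8 between latent and pixel size is an assumption about the SDXL VAE, not something stated in this document.

```python
def derived_range(min_value: int, max_value: int, factor: int) -> range:
    """Values a dimension declared as `factor * Dim(min, max)` can take."""
    # Dim bounds are inclusive, so the upper end is factor * max_value.
    return range(factor * min_value, factor * max_value + 1, factor)


latent_sizes = derived_range(16, 32, 4)  # mirrors 4 * _H and 4 * _W above

VAE_SCALE = 8  # assumption: the VAE downsamples each side by 8x
pixel_sizes = [VAE_SCALE * size for size in latent_sizes]

print(latent_sizes[0], latent_sizes[-1], len(latent_sizes))
print(pixel_sizes[0], pixel_sizes[-1])
```

Under these assumptions the latent side spans 64 to 128 in steps of 4 (17 values), which corresponds to images between 512 and 1024 pixels per side.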
Loading a LoRA adapter uses the standard HuggingFace API — no Torch-TensorRT
calls required. MutableTorchTensorRTModule detects the weight change and refits
automatically on the next forward call:
pipe.load_lora_weights(
    "stablediffusionapi/load_lora_embeddings",
    weight_name="all-disney-princess-xl-lo.safetensors",
    adapter_name="lora1",
)
pipe.set_adapters(["lora1"], adapter_weights=[1])
pipe.fuse_lora()
pipe.unload_lora_weights()
# Refit is triggered automatically here, which is much faster than recompilation
image = pipe("cinematic photo, princess", num_inference_steps=30).images[0]
image.save("with_lora.jpg")
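A refit is enough here because fusing a LoRA adapter changes weight values but not weight shapes. The sketch below uses the general LoRA update, W' = W + scale * (B @ A), on small nested lists. It illustrates the math only and is not the diffusers implementation: `matmul`, `fuse_lora`, and `shape` are hypothetical helpers.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain nested-list matrix product."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def fuse_lora(weight: Matrix, lora_b: Matrix, lora_a: Matrix, scale: float) -> Matrix:
    """Return weight + scale * (lora_b @ lora_a)."""
    delta = matmul(lora_b, lora_a)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


def shape(m: Matrix) -> Tuple[int, int]:
    return (len(m), len(m[0]))


weight = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # 2 x 3 base weight
lora_b = [[1.0], [2.0]]                      # 2 x 1 (rank 1)
lora_a = [[0.5, 0.0, -0.5]]                  # 1 x 3
fused = fuse_lora(weight, lora_b, lora_a, scale=1.0)
```

The fused matrix has new values but the same 2 x 3 shape as the base weight, so the compiled structure still fits and only the weights need to be refitted.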
Basic Workflow#
For a simpler illustration without diffusers, see the ResNet portion of mutable_torchtrt_module_example:
import torch
import torch_tensorrt
import torchvision.models as models
model = models.resnet18(pretrained=True).eval().cuda()
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    use_explicit_typing=True,  # enabled_precisions deprecated
)
inputs = [torch.rand(1, 3, 224, 224).cuda()]
with torch.no_grad():
    mutable_module(*inputs)  # compiles on first call
model2 = models.resnet18(pretrained=False).eval().cuda()
mutable_module.load_state_dict(model2.state_dict())  # marks for refit
output = mutable_module(*inputs)  # refits, then runs
Dynamic Shapes#
The dynamic-shape hint format mirrors torch.export.export’s dynamic_shapes
argument. Use {} for inputs whose shapes are static, and a torch.export.Dim
for each dynamic axis. Nested dict/list structures (common in diffusion models) are
fully supported — the hint structure must mirror the input structure exactly:
dim_0 = torch.export.Dim("dim_0", min=1, max=50)
dim_1 = torch.export.Dim("dim_1", min=1, max=50)
mutable_module.set_expected_dynamic_shape_range(
    args_dynamic_shape=({0: dim_0}, {1: dim_1}),
    kwargs_dynamic_shape={},
)
Calling set_expected_dynamic_shape_range again clears the cached inputs and
triggers a fresh compilation on the next forward pass.
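The requirement that the hint structure mirror the input structure can be checked mechanically. The sketch below is a hypothetical validator, not part of Torch-TensorRT: `FakeTensor` stands in for a real tensor, strings stand in for `torch.export.Dim` objects, and `hint_mirrors` is an invented helper. It treats non-tensor values as taking `None` as their hint, following the `return_dict` entry in the diffusers example.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Stand-in for a tensor; only the shape matters here."""
    shape: Tuple[int, ...]


def hint_mirrors(value: Any, hint: Any) -> bool:
    """Return True if `hint` has the same nesting as `value`."""
    if isinstance(value, FakeTensor):
        # Tensor leaf: a dict of {axis: dim}, empty when the shape is static.
        return isinstance(hint, dict) and all(
            isinstance(axis, int) and 0 <= axis < len(value.shape) for axis in hint
        )
    if isinstance(value, dict):
        return (
            isinstance(hint, dict)
            and hint.keys() == value.keys()
            and all(hint_mirrors(value[key], hint[key]) for key in value)
        )
    if isinstance(value, (list, tuple)):
        return (
            isinstance(hint, (list, tuple))
            and len(hint) == len(value)
            and all(hint_mirrors(v, h) for v, h in zip(value, hint))
        )
    # Non-tensor leaves (None, bool, ...) take no shape hint.
    return hint is None


kwargs = {
    "encoder_hidden_states": FakeTensor((2, 77, 2048)),
    "added_cond_kwargs": {
        "text_embeds": FakeTensor((2, 1280)),
        "time_ids": FakeTensor((2, 6)),
    },
    "return_dict": False,
}
good_hint = {
    "encoder_hidden_states": {0: "BATCH"},
    "added_cond_kwargs": {"text_embeds": {0: "BATCH"}, "time_ids": {0: "BATCH"}},
    "return_dict": None,
}
# Missing the nested "time_ids" entry, so it does not mirror the inputs.
bad_hint = {
    "encoder_hidden_states": {0: "BATCH"},
    "added_cond_kwargs": {"text_embeds": {0: "BATCH"}},
    "return_dict": None,
}
```

Running a check like this before calling `set_expected_dynamic_shape_range` can catch a missing or misplaced key early, instead of at compile time.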
Engine Caching#
Pass engine-cache settings directly to MutableTorchTensorRTModule to avoid
recompiling on subsequent process starts:
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    use_explicit_typing=True,  # enabled_precisions deprecated
    cache_built_engines=True,
    reuse_cached_engines=True,
    engine_cache_size=1 << 30,  # 1 GiB
)
See Engine Caching for details.
Saving and Loading#
MutableTorchTensorRTModule uses its own save/load API (not torch_tensorrt.save)
because it carries extra state — dynamic-shape descriptors, refit state, etc.:
# Requires use_python_runtime=False (the default)
torch_tensorrt.MutableTorchTensorRTModule.save(mutable_module, "module.pkl")
mutable_module = torch_tensorrt.MutableTorchTensorRTModule.load("module.pkl")
use_python_runtime=True (used in the diffusers examples for pipeline compatibility)
does not support save/load. Switch to the default C++ runtime if serialization is
required.
How the Refit / Recompile Decision Works#
Every attribute write, load_state_dict, and sub-module access is intercepted by
an internal ChangeTriggerWrapper. On each forward call the module evaluates an
internal flag:
| State | Action |
|---|---|
| | Inputs match: run the TRT engine directly. |
| | Weights changed, structure is the same: refit the engine, then run. |
| | Structure changed (new keys, weight shapes) or input shapes changed: full recompile, then run. |
| | A |
If refit fails, the module falls back to a full recompile automatically.
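The decision described above can be sketched as a small pure function. This is a simplified model under stated assumptions, not the module's internal code: the `Action` enum, `decide`, and `steps` are hypothetical names, the enum members are not the library's flag names, and a state dict is reduced to a mapping from parameter name to shape.

```python
from enum import Enum
from typing import Dict, List, Tuple

# Parameter name -> shape; a stand-in for the structure of a state dict.
Structure = Dict[str, Tuple[int, ...]]


class Action(Enum):
    RUN = "run"
    REFIT = "refit"
    RECOMPILE = "recompile"


def decide(
    compiled: Structure,
    current: Structure,
    weights_changed: bool,
    inputs_match: bool,
) -> Action:
    """Pick the cheapest action that keeps the engine consistent."""
    if compiled != current or not inputs_match:
        # New keys, different weight shapes, or different input shapes.
        return Action.RECOMPILE
    if weights_changed:
        return Action.REFIT
    return Action.RUN


def steps(action: Action, refit_succeeds: bool = True) -> List[str]:
    """Expand an action into the steps taken before the engine runs."""
    if action is Action.REFIT and not refit_succeeds:
        return ["refit", "recompile", "run"]  # fallback when refit fails
    if action is Action.RUN:
        return ["run"]
    return [action.value, "run"]


base = {"fc.weight": (10, 512), "fc.bias": (10,)}
resized = {"fc.weight": (20, 512), "fc.bias": (20,)}

unchanged = decide(base, base, weights_changed=False, inputs_match=True)
new_values = decide(base, base, weights_changed=True, inputs_match=True)
new_shapes = decide(base, resized, weights_changed=True, inputs_match=True)
```

Here `unchanged` runs the engine directly, `new_values` refits first, and `new_shapes` forces a recompile. `steps(Action.REFIT, refit_succeeds=False)` models the automatic fallback.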
Comparison with dynamo.compile#
For large models like FLUX where a single upfront compilation is preferred and LoRA
swapping is not needed, torch_tensorrt.dynamo.compile with
immutable_weights=False is the right tool. See torch_export_flux_dev for
the complete example.
| | MutableTorchTensorRTModule | dynamo.compile |
|---|---|---|
| Weight change detection | Automatic | Manual |
| LoRA with diffusers | Drop-in | Requires manual refit after LoRA fusion |
| Save / Load | | |
| Dynamic shapes | | |
| Best for | Diffusion pipelines, LoRA swaps, iterative fine-tuning | Large one-shot models (FLUX, LLMs), maximum control |