Weight Refit#
Note
This page documents the design for weight refit in Torch-TensorRT. Original design discussions: RFC #2900, RFC #3204.
Goal#
Allow compiled TensorRT engines to have their weights updated after compilation without rebuilding the engine from scratch. Engine builds involve expensive kernel auto-tuning; refit skips that entirely and just copies new weight values into the already-built engine, which is typically 80–95% faster than a full rebuild.
Primary use cases:
LoRA / adapter hot-swapping — apply a new adapter (e.g. a LoRA for Stable Diffusion) to a pre-compiled TRT engine in seconds rather than minutes.
A/B testing — switch model weight variants without recompilation.
Cloud pre-compiled engines — distribute a weight-stripped engine; end users fill weights locally.
Parameter-efficient fine-tuning — freeze the backbone in TRT and only refit adapter layers on each training step.
High-Level Design#
The compilation pipeline is extended from:
lowering → partitioning → compilation
to:
lowering → partitioning → compilation → refit
During the initial compilation a refit map is constructed — a lookup table
mapping original PyTorch parameter names to their corresponding TensorRT layer
indices. This map is stored inside every TorchTRTModule and is used later to
efficiently copy new weights without traversing the full graph again.
Compilation Modes#
Three engine modes are supported (controlled by make_refittable and
strip_engine_weights):
Weightless + refittable (strip_engine_weights=True, make_refittable=True) — the engine stores only the computation graph; weights are supplied at runtime via refit. Cache-friendly: the engine file is much smaller, and any engine is cacheable regardless of weight values.
Refittable with embedded weights (make_refittable=True) — the engine stores both the computation graph and the current weights. Refit replaces the embedded weights in place (the kREFIT_IDENTICAL semantic in TensorRT).
Non-refittable (legacy default) — weights are baked into the engine at build time; no post-build updates are possible.
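The three modes above can be summarized as flag combinations. The dictionary below is purely illustrative: the flag names follow this design doc, and the shipped API may spell them differently.

```python
# Illustrative flag combinations for the three engine modes described above.
# Flag names (strip_engine_weights, make_refittable) follow this design doc.
MODES = {
    # Graph only; weights are pushed in at runtime via refit.
    "weightless_refittable": {"strip_engine_weights": True, "make_refittable": True},
    # Graph + weights; refit overwrites the embedded weights in place.
    "refittable_embedded": {"strip_engine_weights": False, "make_refittable": True},
    # Legacy default: weights baked in, no post-build updates.
    "non_refittable": {"strip_engine_weights": False, "make_refittable": False},
}
```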
User-Facing API#
refit_module_weights#
Refit a compiled torch.fx.GraphModule with weights from a new exported program:
import torch_tensorrt
from torch_tensorrt.dynamo import refit_module_weights
# Compile once
exp_program = torch.export.export(model, inputs)
trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs, make_refittable=True)
# Later: update weights (e.g. different LoRA applied to model)
new_model = MyModel() # same architecture, different weights
new_exp_program = torch.export.export(new_model, inputs)
refitted_gm = refit_module_weights(
    compiled_module=trt_gm,
    new_weight_module=new_exp_program,
    inputs=inputs,
)
MutableTorchTensorRTModule#
A higher-level nn.Module wrapper that intercepts weight mutations and
dispatches to refit automatically:
from torch_tensorrt.dynamo import MutableTorchTensorRTModule
mutable = MutableTorchTensorRTModule(model, config=settings)
# Weight update (e.g. HuggingFace diffusers LoRA load_lora_weights)
pipeline.unet = mutable
pipeline.load_lora_weights("path/to/lora")
# → intercepted; refit triggered automatically, no recompilation
# If the model architecture changes (new adapter inserts layers),
# a full recompilation is triggered instead (engine cache is consulted first).
Internal Implementation#
Refit Map Construction#
During conversion the FX interpreter inspects each INetworkDefinition layer that
carries learnable weights (convolutions, deconvolutions, BatchNorm, LayerNorm,
constant layers) and records a mapping:
{ "pytorch.param.name" : trt_layer_index }
This map is serialized alongside the engine bytes and stored as part of the
torch.classes.tensorrt.Engine object.
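The "serialized alongside the engine bytes" step can be sketched as a simple payload format. Both helper names and the byte layout below are hypothetical, for illustration only; the real torch.classes.tensorrt.Engine object stores this internally.

```python
import json
import struct

def pack_engine_payload(engine_bytes: bytes, refit_map: dict) -> bytes:
    """Bundle serialized engine bytes with the name -> layer-index refit map.

    Hypothetical layout: 8-byte big-endian map length, JSON-encoded map,
    then the raw engine bytes.
    """
    map_blob = json.dumps(refit_map).encode("utf-8")
    return struct.pack(">Q", len(map_blob)) + map_blob + engine_bytes

def unpack_engine_payload(payload: bytes) -> tuple:
    """Split a payload back into (engine_bytes, refit_map)."""
    (map_len,) = struct.unpack(">Q", payload[:8])
    refit_map = json.loads(payload[8 : 8 + map_len].decode("utf-8"))
    return payload[8 + map_len :], refit_map
```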
from typing import Sequence

import numpy as np
import torch

from torch_tensorrt import Input
from torch_tensorrt.dynamo import CompilationSettings

def construct_refit_mapping(
    module: torch.fx.GraphModule,
    inputs: Sequence[Input],
    settings: CompilationSettings = CompilationSettings(),
) -> dict[str, np.ndarray]:
    """
    Run the interpreter and find the weight mapping between
    the exported program's state_dict and TensorRT engine weights.

    Returns: { trt_weight_name -> numpy weight array }
    """
Weight Application#
refit_module_weights re-runs the compilation settings stored in the compiled
module to re-trace the new exported program through the ATen lowering stage only
(no partitioning or engine rebuild), then:
Iterates over TRT submodules in the compiled graph.
For each, constructs a fresh INetworkDefinition from the new weights.
Uses nvinfer1::IRefitter to push the new weights into the existing engine.
Returns a copy of the compiled module with the updated engines.
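The weight-copy step above reduces to a loop over the refit map once the mapping is in hand. The sketch below simulates it in pure Python: FakeRefitter is a hypothetical stand-in mirroring the shape of the nvinfer1::IRefitter API (stage named weights, then commit), not the real binding.

```python
import numpy as np

class FakeRefitter:
    """Hypothetical stand-in for nvinfer1::IRefitter: stage weights, then commit."""

    def __init__(self, engine_weights: dict):
        self._engine_weights = engine_weights
        self._pending: dict = {}

    def set_named_weights(self, name: str, weights: np.ndarray) -> bool:
        # TensorRT rejects names that don't exist in the engine.
        if name not in self._engine_weights:
            return False
        self._pending[name] = weights
        return True

    def refit_cuda_engine(self) -> bool:
        # Commit all staged weights at once.
        self._engine_weights.update(self._pending)
        return True

def apply_refit(refit_map: dict, new_state_dict: dict, refitter: FakeRefitter) -> None:
    # refit_map: PyTorch parameter name -> TRT engine weight name
    for torch_name, trt_name in refit_map.items():
        assert refitter.set_named_weights(trt_name, new_state_dict[torch_name])
    assert refitter.refit_cuda_engine()
```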
Non-Refittable Ops#
Some ops cannot be refitted because TensorRT embeds their outputs as constants
(e.g. aten.cumsum, aten.embedding_bag). Two options are available:
Keep engines refittable and fall those ops back to PyTorch (execute them outside TensorRT).
Set make_refittable=False and rebuild the engine when weights change.
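If the first option is taken, the partitioner can be told to execute those ops in PyTorch. The fragment below is a sketch of such compile settings; torch_executed_ops is a real torch_tensorrt.dynamo.compile argument, but the exact overload spellings shown are illustrative.

```python
# Sketch: keep the non-refittable ops named in this doc in PyTorch so the
# remaining TRT subgraphs can stay refittable. Op spellings are illustrative.
NON_REFITTABLE_OPS = {
    "torch.ops.aten.cumsum.default",
    "torch.ops.aten.embedding_bag.default",
}

compile_kwargs = dict(
    make_refittable=True,              # flag name per this design doc
    torch_executed_ops=NON_REFITTABLE_OPS,  # fall these ops back to PyTorch
)
```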
Refit Caching Shortcut#
If the mapping between state_dict keys and TRT engine weight names is stable
across calls (same model, different weights), the map is cached so that
re-interpretation of the exported program can be skipped entirely on subsequent
refits — only the weight copy step runs.
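The caching pattern described above can be sketched as a memoized lookup; the cache key and helper names here are hypothetical, and only the pattern (build the map once, reuse it on every subsequent refit) is the point.

```python
# Hypothetical refit-map cache: the expensive re-interpretation runs only on
# the first refit for a given model; later refits reuse the cached mapping.
_REFIT_MAP_CACHE: dict = {}

def get_refit_map(model_key: str, build_map) -> dict:
    """Return the cached param-name -> TRT-weight-name map, building it once.

    build_map is the expensive step (re-running the interpreter); model_key
    identifies the architecture, which is stable across weight updates.
    """
    if model_key not in _REFIT_MAP_CACHE:
        _REFIT_MAP_CACHE[model_key] = build_map()
    return _REFIT_MAP_CACHE[model_key]
```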