Refitting TensorRT Engines with Updated Weights#
Compiling a TensorRT engine is expensive — it can take seconds to minutes depending on
model size and optimization level. For workflows where the graph structure is fixed but
the weights change (fine-tuning checkpoints, LoRA adapters, EMA weight updates),
torch_tensorrt.dynamo.refit_module_weights() swaps the weights inside an already-built
engine without rerunning the TRT optimizer.
When to Use Refit#
Refit is the right tool when:

- You have a compiled TRT module and a new PyTorch checkpoint with the same architecture but different weights.
- The weight update is frequent (e.g., loading a new LoRA adapter for each request).
- Full recompilation is too slow for your workflow.

Refit is not appropriate when:

- The model architecture has changed (different layers, shapes, or graph topology). In that case you must recompile. torch_tensorrt.MutableTorchTensorRTModule handles this case automatically.
- The compiled engine was built with immutable_weights=True (the default). You must compile with immutable_weights=False to make engines refittable.
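The architecture constraint can be checked cheaply before choosing between refit and recompile. A minimal sketch (the same_architecture helper and the example shape dicts are hypothetical, not part of the torch_tensorrt API) that compares parameter names and shapes from two state dicts:

```python
def same_architecture(old_shapes, new_shapes):
    """Hypothetical helper: True iff two checkpoints expose identical
    parameter names and shapes, a necessary condition for refit."""
    if old_shapes.keys() != new_shapes.keys():
        return False
    return all(old_shapes[k] == new_shapes[k] for k in old_shapes)

# In practice the dicts would come from
# {k: tuple(v.shape) for k, v in model.state_dict().items()}.
old = {"conv.weight": (8, 3, 3, 3), "conv.bias": (8,)}
new_ok = {"conv.weight": (8, 3, 3, 3), "conv.bias": (8,)}
new_bad = {"conv.weight": (16, 3, 3, 3), "conv.bias": (16,)}

print(same_architecture(old, new_ok))   # True  -> refit is viable
print(same_architecture(old, new_bad))  # False -> recompile
```

A False here means the graph topology differs and refit_module_weights would fail; recompile (or use MutableTorchTensorRTModule) instead.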
Requirements#
Compile the original model with immutable_weights=False:

```python
import torch
import torch_tensorrt

model = MyModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224).cuda()]

exp_program = torch.export.export(model, tuple(inputs))
compiled = torch_tensorrt.dynamo.compile(
    exp_program,
    arg_inputs=inputs,
    immutable_weights=False,  # required for refit
)
```
Export the updated model (same architecture, new weights) as an ExportedProgram:

```python
updated_model = MyModel().eval().cuda()
# load updated_model weights from checkpoint...
new_exp_program = torch.export.export(updated_model, tuple(inputs))
```
Call torch_tensorrt.dynamo.refit_module_weights():

```python
from torch_tensorrt.dynamo import refit_module_weights

refitted = refit_module_weights(
    compiled_module=compiled,
    new_weight_module=new_exp_program,
)
```

refitted is a new torch.fx.GraphModule with the TRT engines updated to use the new weights. The original compiled module is unchanged (a deep copy is made by default).
API#
```python
torch_tensorrt.dynamo.refit_module_weights(
    compiled_module,
    new_weight_module,
    arg_inputs=None,
    kwarg_inputs=None,
    verify_output=False,
    use_weight_map_cache=True,
    in_place=False,
)
```
Parameters
- compiled_module (torch.fx.GraphModule | ExportedProgram): The compiled TRT module to update. Must have been compiled with immutable_weights=False. Can be loaded from disk via torch_tensorrt.load().
- new_weight_module (ExportedProgram): Exported program containing the updated weights. Must have the same model architecture (graph topology and tensor shapes) as the original.
- arg_inputs (Tuple[Any, ...], optional): Sample positional inputs. Required only when verify_output=True.
- kwarg_inputs (dict[str, Any], optional): Sample keyword inputs. Required only when verify_output=True.
- verify_output (bool, default False): Run a numerical check comparing the output of the refitted TRT engine against PyTorch on the provided sample inputs. Useful for catching silent refit failures during development.
- use_weight_map_cache (bool, default True): When a torch-tensorrt program is compiled, the TRTInterpreter builds a map of which exported-program nodes correspond to which TensorRT layers, and this mapping is stored as metadata in the serialized program. The cached map is not guaranteed to match a new, unseen exported program exactly, but when it does, it reduces refit time by roughly 50%.
- in_place (bool, default False): If True, modify the compiled module in-place rather than returning a copy. Not supported for ExportedProgram inputs (use the returned module instead).
Returns torch.fx.GraphModule — the refitted compiled module.
Output Verification#
Use verify_output=True during development to catch numerical mismatches between the
refitted TRT engine and PyTorch:
```python
inputs = [torch.randn(1, 3, 224, 224).cuda()]
refitted = refit_module_weights(
    compiled_module=compiled,
    new_weight_module=new_exp_program,
    arg_inputs=tuple(inputs),
    verify_output=True,
)
```
A warning is logged if the outputs differ beyond floating-point tolerance.
Batch Norm and Constant Folding#
BatchNorm layers are typically constant-folded into the preceding convolution during
export. refit_module_weights handles this automatically: it reconstructs the
folded weight, bias, running_mean, and running_var tensors from the
updated BatchNorm state dict and maps them to the correct fused TRT layer.
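The folding arithmetic itself is standard: each output channel of the convolution is rescaled by gamma / sqrt(var + eps), and the bias is shifted to absorb the running mean and beta. A pure-Python sketch of the per-channel math (illustrative only; refit_module_weights performs the equivalent reconstruction on the real tensors):

```python
import math

def fold_bn_channel(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold BatchNorm stats for one output channel into the conv
    weight vector w and bias b, so that
    dot(folded_w, x) + folded_b == bn(dot(w, x) + b)."""
    scale = gamma / math.sqrt(var + eps)
    folded_w = [wi * scale for wi in w]
    folded_b = (b - mean) * scale + beta
    return folded_w, folded_b

# Check the identity on a sample input.
w, b = [0.5, -1.0, 2.0], 0.1
gamma, beta, mean, var = 1.5, 0.3, 0.2, 0.8
x = [1.0, 2.0, 3.0]

conv = sum(wi * xi for wi, xi in zip(w, x)) + b
bn_out = (conv - mean) / math.sqrt(var + 1e-5) * gamma + beta

fw, fb = fold_bn_channel(w, b, gamma, beta, mean, var)
folded_out = sum(wi * xi for wi, xi in zip(fw, x)) + fb
assert abs(bn_out - folded_out) < 1e-9
```

Because the fused TRT layer only stores the folded weight and bias, refit must redo this computation from the new BatchNorm statistics before writing into the engine.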
Saving and Loading Refitted Modules#
Refitted modules can be saved and loaded exactly like any other compiled module:
```python
torch_tensorrt.save(refitted, "model_v2.ep", arg_inputs=inputs)
# later:
refitted = torch_tensorrt.load("model_v2.ep")
```
Refit vs. MutableTorchTensorRTModule#
Use torch_tensorrt.MutableTorchTensorRTModule when you need automatic
handling of both weight mutations and structural mutations:
| Scenario | refit_module_weights | MutableTorchTensorRTModule |
|---|---|---|
| New checkpoint, same architecture | Yes — explicit, controlled | Yes — automatic |
| LoRA adapter changes graph topology | No — must recompile manually | Yes — detects structural change, recompiles automatically |
| HuggingFace | Requires custom glue code | Drop-in |
| Fine-grained control over refit timing | Yes | No — mutation-triggered |