{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Refitting Torch-TensorRT Programs with New Weights\n\nCompilation is an expensive operation because it involves many graph transformations, translations,\nand optimizations applied to the model. In cases where the weights of a model might be updated\noccasionally (e.g. inserting LoRA adapters), the large cost of recompilation can make it infeasible\nto use TensorRT if the compiled program needs to be built from scratch each time. Torch-TensorRT\nprovides a PyTorch-native mechanism, weight refitting, to update the weights of a compiled TensorRT\nprogram without recompiling from scratch.\n\nIn this tutorial, we are going to walk through\n\n 1. Compiling a PyTorch model to a TensorRT Graph Module\n 2. Saving and loading a graph module\n 3. Refitting the graph module\n\nThis tutorial focuses mostly on the AOT workflow, where a user is most likely to need to\nmanually refit a module. In the JIT workflow, weight changes trigger recompilation. 
However, if an engine cache\nis enabled and the engine has previously been built, Torch-TensorRT can automatically recognize\nthe previously built engine, trigger a refit, and short-circuit recompilation on behalf of the user (see: `engine_caching_example`).\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Standard Workflow\n\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Imports and model definition\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import numpy as np\nimport torch\nimport torch_tensorrt as torch_trt\nimport torchvision.models as models\nfrom torch_tensorrt.dynamo import refit_module_weights\n\nnp.random.seed(0)\ntorch.manual_seed(0)\ninputs = [torch.rand((1, 3, 224, 224)).to(\"cuda\")]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make a refittable Compilation Program\n\nThe initial step is to compile a module and save it, as in a normal workflow. Note that there is an\nadditional parameter `immutable_weights` that is set to `False`. This parameter indicates that\nthe engine being built should support weight refitting later. 
Engines built without\nthis setting cannot be refit.\n\nIn this case we are going to compile a ResNet18 model with randomly initialized weights and save it.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = models.resnet18(pretrained=False).to(\"cuda\").eval()\nexp_program = torch.export.export(model, tuple(inputs))\nenabled_precisions = {torch.float}\nworkspace_size = 20 << 30\nmin_block_size = 0\nuse_python_runtime = False\ntorch_executed_ops = {}\ntrt_gm = torch_trt.dynamo.compile(\n    exp_program,\n    tuple(inputs),\n    use_python_runtime=use_python_runtime,\n    enabled_precisions=enabled_precisions,\n    min_block_size=min_block_size,\n    torch_executed_ops=torch_executed_ops,\n    immutable_weights=False,\n    reuse_cached_engines=False,\n)  # Output is a torch.fx.GraphModule\n\n# Save the graph module as an exported program\ntorch_trt.save(trt_gm, \"./compiled.ep\", inputs=inputs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Refit the Program with Pretrained Weights\n\nRandom weights are not useful for inference. Instead of recompiling the model, however, we can\nrefit the compiled program with the pretrained weights. This is done by setting up another PyTorch module\nwith the target weights and exporting it as an ExportedProgram. 
Then the ``refit_module_weights``\nfunction is used to update the weights of the compiled module with the new weights.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Create and export the updated model\nmodel2 = models.resnet18(pretrained=True).to(\"cuda\").eval()\nexp_program2 = torch.export.export(model2, tuple(inputs))\n\n\ncompiled_trt_ep = torch_trt.load(\"./compiled.ep\")\n\n# This returns a new module with updated weights\nnew_trt_gm = refit_module_weights(\n    compiled_module=compiled_trt_ep,\n    new_weight_module=exp_program2,\n    arg_inputs=inputs,\n)\n\n# Check the output\nwith torch.no_grad():\n    expected_outputs, refitted_outputs = (\n        exp_program2.module()(*inputs),\n        new_trt_gm(*inputs),\n    )\n    for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):\n        assert torch.allclose(\n            expected_output, refitted_output, 1e-2, 1e-2\n        ), \"Refit Result is not correct. Refit failed\"\n\nprint(\"Refit successfully!\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Advanced Usage\n\nThere are a number of settings you can use to control the refit process.\n\n### Weight Map Cache\n\nWeight refitting works by matching the weights of the compiled module with the new weights from\nthe user-supplied ExportedProgram. Since 1:1 name matching from PyTorch to TensorRT is hard to accomplish,\nthe only guaranteed way to match weights at *refit-time* is to pass the new ExportedProgram through the\nearly phases of the compilation process to generate near-identical weight names. This can be expensive\nand is not always necessary.\n\nTo avoid this, **at initial compile**, Torch-TensorRT will attempt to cache a direct mapping from PyTorch\nweights to TensorRT weights. This cache is stored in the compiled module as metadata and can be used\nto speed up refit. If the cache is not present, the refit system will fall back to rebuilding the mapping at\nrefit-time. 
Use of this cache is controlled by the ``use_weight_map_cache`` parameter.\n\nSince the cache uses a heuristic-based system for matching PyTorch and TensorRT weights, you may want to verify the refit. This can be done by setting\n``verify_output`` to ``True`` and providing sample ``arg_inputs`` and ``kwarg_inputs``. When this is done, the refit\nsystem will run the refitted module and the user-supplied module on the same inputs and compare the outputs.\n\n### In-Place Refit\n\nThe ``in_place`` parameter allows the user to refit the module in place. This is useful when the user wants to update the weights\nof the compiled module without creating a new module.\n\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }