{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Saving and Loading Models with Dynamic Shapes\n",
    "\n",
    "This example demonstrates how to save and load Torch-TensorRT compiled models\n",
    "with dynamic input shapes. When you compile a model with dynamic shapes,\n",
    "you need to preserve the dynamic shape specifications when saving the model\n",
    "to ensure it can handle variable input sizes after deserialization.\n",
    "\n",
    "The API is designed to feel similar to `torch.export`'s handling of dynamic shapes\n",
    "for consistency and ease of use.\n",
    "\n",
    "Running this notebook requires a CUDA-capable GPU and an installed `torch_tensorrt`."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports and Model Definition"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import tempfile\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch_tensorrt\n",
    "\n",
    "# Seed the RNG so the randomly initialised weights and inputs are reproducible.\n",
    "torch.manual_seed(0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a simple model that we'll compile with a dynamic batch size."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MyModel(nn.Module):\n",
    "    \"\"\"Conv -> ReLU -> flatten -> Linear model producing 10 outputs per sample.\n",
    "\n",
    "    The linear layer's input width is fixed at 16 * 224 * 224, so inputs must\n",
    "    be 3 x 224 x 224; only the batch dimension can vary.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1)\n",
    "        self.relu = nn.ReLU()\n",
    "        self.linear = nn.Linear(16 * 224 * 224, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.conv(x)\n",
    "        x = self.relu(x)\n",
    "        # Keep the batch dimension and flatten channels and spatial dims.\n",
    "        x = x.flatten(1)\n",
    "        x = self.linear(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compile with Dynamic Shapes\n",
    "\n",
    "First, we compile the model with a dynamic batch dimension."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = MyModel().eval().cuda()\n",
    "\n",
    "# Define example input with batch size 2\n",
    "example_input = torch.randn(2, 3, 224, 224).cuda()\n",
    "\n",
    "# Define dynamic batch dimension using torch.export.Dim\n",
    "# This allows batch sizes from 1 to 32\n",
    "dyn_batch = torch.export.Dim(\"batch\", min=1, max=32)\n",
    "\n",
    "# Specify which dimensions are dynamic\n",
    "dynamic_shapes = {\"x\": {0: dyn_batch}}\n",
    "\n",
    "# Export the model with dynamic shapes\n",
    "exp_program = torch.export.export(\n",
    "    model, (example_input,), dynamic_shapes=dynamic_shapes, strict=False\n",
    ")\n",
    "\n",
    "# Compile with Torch-TensorRT\n",
    "# The shape bounds below mirror the Dim range declared above (batch 1 to 32)\n",
    "compile_spec = {\n",
    "    \"inputs\": [\n",
    "        torch_tensorrt.Input(\n",
    "            min_shape=(1, 3, 224, 224),\n",
    "            opt_shape=(8, 3, 224, 224),\n",
    "            max_shape=(32, 3, 224, 224),\n",
    "            dtype=torch.float32,\n",
    "        )\n",
    "    ],\n",
    "    \"enabled_precisions\": {torch.float32},\n",
    "    \"min_block_size\": 1,\n",
    "}\n",
    "\n",
    "trt_gm = torch_tensorrt.dynamo.compile(exp_program, **compile_spec)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Compiled Model with Different Batch Sizes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Test with batch size 4\n",
    "input_bs4 = torch.randn(4, 3, 224, 224).cuda()\n",
    "output_bs4 = trt_gm(input_bs4)\n",
    "\n",
    "# Test with batch size 16\n",
    "input_bs16 = torch.randn(16, 3, 224, 224).cuda()\n",
    "output_bs16 = trt_gm(input_bs16)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save, Load, and Verify the Model with Dynamic Shapes\n",
    "\n",
    "The key is to pass the same `dynamic_shapes` specification to `save()`.\n",
    "\n",
    "Saving, loading, and verification share one cell because the temporary\n",
    "directory (and the saved file inside it) is removed when the `with` block exits."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as tmpdir:\n",
    "    save_path = f\"{tmpdir}/dynamic_model.ep\"\n",
    "\n",
    "    # Save with dynamic_shapes parameter - this is crucial for preserving dynamic behavior\n",
    "    torch_tensorrt.save(\n",
    "        trt_gm,\n",
    "        save_path,\n",
    "        output_format=\"exported_program\",\n",
    "        arg_inputs=[example_input],\n",
    "        dynamic_shapes=dynamic_shapes,  # Same as used during export\n",
    "    )\n",
    "\n",
    "    # Load the saved model while the temporary file still exists\n",
    "    loaded_model = torch_tensorrt.load(save_path).module()\n",
    "\n",
    "    # Test with the same batch sizes to verify dynamic shapes are preserved\n",
    "    output_loaded_bs4 = loaded_model(input_bs4)\n",
    "    output_loaded_bs16 = loaded_model(input_bs16)\n",
    "\n",
    "    assert torch.allclose(\n",
    "        output_bs4, output_loaded_bs4, rtol=1e-3, atol=1e-3\n",
    "    ), \"Loaded model output differs from compiled model for batch size 4\"\n",
    "    assert torch.allclose(\n",
    "        output_bs16, output_loaded_bs16, rtol=1e-3, atol=1e-3\n",
    "    ), \"Loaded model output differs from compiled model for batch size 16\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Example with Multiple Dynamic Dimensions\n",
    "\n",
    "The same workflow applies when several dimensions are dynamic. Here the batch,\n",
    "height, and width of the input all vary."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "class MultiDimModel(nn.Module):\n",
    "    \"\"\"Single 3x3 convolution (3 -> 16 channels, stride 1, padding 1).\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = nn.Conv2d(3, 16, 3, stride=1, padding=1)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.conv(x)\n",
    "\n",
    "\n",
    "model2 = MultiDimModel().eval().cuda()\n",
    "example_input2 = torch.randn(2, 3, 128, 128).cuda()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Export and compile with dynamic batch, height, and width."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# Define dynamic dimensions for batch and spatial dimensions\n",
    "dyn_batch2 = torch.export.Dim(\"batch\", min=1, max=16)\n",
    "dyn_height = torch.export.Dim(\"height\", min=64, max=512)\n",
    "dyn_width = torch.export.Dim(\"width\", min=64, max=512)\n",
    "\n",
    "dynamic_shapes2 = {\"x\": {0: dyn_batch2, 2: dyn_height, 3: dyn_width}}\n",
    "\n",
    "exp_program2 = torch.export.export(\n",
    "    model2, (example_input2,), dynamic_shapes=dynamic_shapes2, strict=False\n",
    ")\n",
    "\n",
    "# The shape bounds below mirror the Dim ranges declared above\n",
    "compile_spec2 = {\n",
    "    \"inputs\": [\n",
    "        torch_tensorrt.Input(\n",
    "            min_shape=(1, 3, 64, 64),\n",
    "            opt_shape=(8, 3, 256, 256),\n",
    "            max_shape=(16, 3, 512, 512),\n",
    "            dtype=torch.float32,\n",
    "        )\n",
    "    ],\n",
    "    \"enabled_precisions\": {torch.float32},\n",
    "}\n",
    "\n",
    "trt_gm2 = torch_tensorrt.dynamo.compile(exp_program2, **compile_spec2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save, reload, and verify the model on an input whose batch size, height, and\n",
    "width all differ from the example input used for export."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "with tempfile.TemporaryDirectory() as tmpdir:\n",
    "    save_path2 = f\"{tmpdir}/multi_dim_model.ep\"\n",
    "\n",
    "    torch_tensorrt.save(\n",
    "        trt_gm2,\n",
    "        save_path2,\n",
    "        output_format=\"exported_program\",\n",
    "        arg_inputs=[example_input2],\n",
    "        dynamic_shapes=dynamic_shapes2,\n",
    "    )\n",
    "\n",
    "    loaded_model2 = torch_tensorrt.load(save_path2).module()\n",
    "\n",
    "    # Test with different input shapes\n",
    "    test_input = torch.randn(4, 3, 256, 256).cuda()\n",
    "    output = loaded_model2(test_input)\n",
    "\n",
    "    # Compare against the in-memory compiled module, as in the first example\n",
    "    output_reference = trt_gm2(test_input)\n",
    "    # padding=1 with a 3x3 kernel and stride 1 keeps the spatial size unchanged\n",
    "    assert output.shape == (4, 16, 256, 256), f\"Unexpected output shape {tuple(output.shape)}\"\n",
    "    assert torch.allclose(\n",
    "        output_reference, output, rtol=1e-3, atol=1e-3\n",
    "    ), \"Loaded model output differs from compiled model for the multi-dim input\""
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}