{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n\n# Compiling Models with Dynamic Input Shapes\n\nDynamic shapes are essential when your model\nneeds to handle varying batch sizes or sequence lengths at inference time without recompilation.\n\nThe example uses a Vision Transformer-style model with expand and reshape operations,\nwhich are common patterns that benefit from dynamic shape handling.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports and Model Definition\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import logging\n\nimport torch\nimport torch.nn as nn\nimport torch_tensorrt\n\nlogging.basicConfig(level=logging.DEBUG)\n\ntorch.manual_seed(0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Define a model with expand and reshape operations\n# This is a simplified Vision Transformer pattern with:\n# - A learnable class token that needs to expand to match batch size\n# - A QKV projection followed by reshaping for multi-head attention\nclass ExpandReshapeModel(nn.Module):\n    def __init__(self, embed_dim: int):\n        super().__init__()\n        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))\n        self.embed_dim = embed_dim\n        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)\n\n    def forward(self, x: torch.Tensor):\n        batch_size = x.shape[0]\n        cls_token = self.cls_token.expand(batch_size, -1, -1)\n        x = torch.cat([cls_token, x], dim=1)\n        x = self.qkv_proj(x)\n        # Split the QKV projection for multi-head attention:\n        # (batch, seq_len, 3, num_heads=12, head_dim=embed_dim // 12)\n        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)\n        return reshaped_qkv\n\n\nmodel = ExpandReshapeModel(embed_dim=768).cuda().eval()\n# Batch of 4, 196 patches (14 x 14), embed_dim 768\nx = torch.randn(4, 196, 768).cuda()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Approach 1: JIT Compilation with `torch.compile`\n\nThe first approach uses PyTorch's `torch.compile` with the TensorRT backend.\nThis is a Just-In-Time (JIT) 
compilation method where the model is compiled\nduring the first inference call.\n\nKey points:\n\n- Use `torch._dynamo.mark_dynamic()` to specify which dimensions are dynamic\n- The `index` parameter indicates which dimension (0 = batch dimension)\n- Provide `min` and `max` bounds for the dynamic dimension\n- The model will work for any batch size within the specified range\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x1 = x.clone()\n# Mark the batch dimension (dim 0) as dynamic within [2, 32]\ntorch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)\ntrt_module = torch.compile(model, backend=\"tensorrt\")\n# Compilation happens here, on the first call\nout1 = trt_module(x1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Approach 2: AOT Compilation with `torch_tensorrt.compile`\n\nThe second approach uses Ahead-Of-Time (AOT) compilation with `torch_tensorrt.compile`.\nThis compiles the model up front, before the first inference call.\n\nKey points:\n\n- Use `torch_tensorrt.Input()` to specify dynamic shape ranges\n- Provide `min_shape`, `opt_shape`, and `max_shape` for each input\n- The `opt_shape` is used for optimization and should represent typical input sizes\n- Set `ir=\"dynamo\"` to use the Dynamo frontend\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x2 = x.clone()\nexample_input = torch_tensorrt.Input(\n    min_shape=[1, 196, 768],\n    opt_shape=[4, 196, 768],\n    max_shape=[32, 196, 768],\n    dtype=torch.float32,\n)\ntrt_module = torch_tensorrt.compile(model, ir=\"dynamo\", inputs=[example_input])\nout2 = trt_module(x2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Approach 3: AOT with `torch.export` + Dynamo Compile\n\nThe third approach uses PyTorch 2.0's `torch.export` API combined with\nTorch-TensorRT's Dynamo compiler. 
This provides the most explicit control\nover dynamic shapes.\n\nKey points:\n\n- Use `torch.export.Dim()` to define symbolic dimensions with constraints\n- Create a `dynamic_shapes` dictionary mapping inputs to their dynamic dimensions\n- Export the model to an `ExportedProgram` with these constraints\n- Compile the exported program with `torch_tensorrt.dynamo.compile`\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "x3 = x.clone()\n# Symbolic batch dimension with explicit bounds\nbs = torch.export.Dim(\"bs\", min=1, max=32)\ndynamic_shapes = {\"x\": {0: bs}}\nexp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)\ntrt_module = torch_tensorrt.dynamo.compile(exp_program, inputs=[x3])\nout3 = trt_module(x3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Verify All Approaches Produce Equivalent Results\n\nAll three approaches should produce numerically equivalent results.\nAs a final check, the last compiled module is also run at a batch size\ndifferent from the compile-time input to confirm the engine really is dynamic.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "assert torch.allclose(out1, out2)\nassert torch.allclose(out1, out3)\nassert torch.allclose(out2, out3)\n\n# Exercise a batch size other than the one used at compile time\nx_dyn = torch.randn(2, 196, 768).cuda()\nassert torch.allclose(trt_module(x_dyn), model(x_dyn), rtol=1e-3, atol=1e-3)\n\nprint(\"All approaches produced equivalent results, including at a new batch size!\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.15" } }, "nbformat": 4, "nbformat_minor": 0 }