
Compiling Models with Dynamic Input Shapes

Dynamic shapes are essential when your model needs to handle varying batch sizes or sequence lengths at inference time without recompilation.

The example uses a Vision Transformer-style model with expand and reshape operations, which are common patterns that benefit from dynamic shape handling.

Imports and Model Definition

import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

torch.manual_seed(0)
# Define a model with expand and reshape operations
# This is a simplified Vision Transformer pattern with:
# - A learnable class token that needs to expand to match batch size
# - A QKV projection followed by reshaping for multi-head attention
class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
        return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()
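Before compiling, it can help to confirm the expected output shape in eager mode. The sketch below (a CPU-only copy of the model, so it runs without a GPU) shows that the class token adds one sequence position and the QKV projection triples the embedding dimension, so an input of shape (batch, 196, 768) becomes (batch, 197, 3, 12, 64), since 768 * 3 = 3 * 12 * 64.

```python
import torch
import torch.nn as nn

# CPU-only copy of the tutorial's model, used just to sanity-check shapes.
class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # Expand the class token to the (runtime) batch size, prepend it,
        # then reshape the QKV projection for 12 attention heads.
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        return x.reshape(batch_size, x.size(1), 3, 12, -1)

cpu_model = ExpandReshapeModel(embed_dim=768).eval()
with torch.no_grad():
    out = cpu_model(torch.randn(4, 196, 768))
print(out.shape)  # torch.Size([4, 197, 3, 12, 64])
```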

Approach 1: JIT Compilation with torch.compile

The first approach uses PyTorch’s torch.compile with the TensorRT backend. This is a Just-In-Time (JIT) compilation method where the model is compiled during the first inference call.

Key points:

  • Use torch._dynamo.mark_dynamic() to specify which dimensions are dynamic

  • The index parameter indicates which dimension (0 = batch dimension)

  • Provide min and max bounds for the dynamic dimension

  • The model will work for any batch size within the specified range

x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)

Approach 2: AOT Compilation with torch_tensorrt.compile

The second approach uses Ahead-Of-Time (AOT) compilation with torch_tensorrt.compile. This compiles the model upfront before inference.

Key points:

  • Use torch_tensorrt.Input() to specify dynamic shape ranges

  • Provide min_shape, opt_shape, and max_shape for each input

  • The opt_shape is used for optimization and should represent typical input sizes

  • Set ir="dynamo" to use the Dynamo frontend

x2 = x.clone()
example_input = torch_tensorrt.Input(
    min_shape=[1, 196, 768],
    opt_shape=[4, 196, 768],
    max_shape=[32, 196, 768],
    dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=example_input)
out2 = trt_module(x2)

Approach 3: AOT with torch.export + Dynamo Compile

The third approach uses PyTorch’s torch.export API combined with Torch-TensorRT’s Dynamo compiler. This provides the most explicit control over dynamic shapes.

Key points:

  • Use torch.export.Dim() to define symbolic dimensions with constraints

  • Create a dynamic_shapes dictionary mapping inputs to their dynamic dimensions

  • Export the model to an ExportedProgram with these constraints

  • Compile the exported program with torch_tensorrt.dynamo.compile

x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)

Verify All Approaches Produce Identical Results

All three approaches should produce the same numerical results. This verification ensures that dynamic shape handling works correctly across different compilation methods.

assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)

print("All three approaches produced identical results!")
