Compiling Models with Dynamic Input Shapes
Dynamic shapes are essential when your model needs to handle varying batch sizes or sequence lengths at inference time without recompilation.
The example uses a Vision Transformer-style model whose expand and reshape operations take their sizes from the runtime batch dimension. This is a common pattern, and it must be handled correctly when the batch size is dynamic.
Imports and Model Definition
import logging
import torch
import torch.nn as nn
import torch_tensorrt
logging.basicConfig(level=logging.DEBUG)
torch.manual_seed(0)
# Define a model with expand and reshape operations
# This is a simplified Vision Transformer pattern with:
# - A learnable class token that needs to expand to match batch size
# - A QKV projection followed by reshaping for multi-head attention
class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)
    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        # (B, 1, D): broadcast the learnable class token across the batch
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        # (B, N + 1, D): prepend the class token to the sequence
        x = torch.cat([cls_token, x], dim=1)
        # (B, N + 1, 3 * D): joint Q/K/V projection
        x = self.qkv_proj(x)
        # (B, N + 1, 3, 12, D // 12): split into Q/K/V and 12 attention heads
        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
        return reshaped_qkv
model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()
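Only the batch dimension varies in this model; every other size in forward() is fixed or derived from the fixed sizes. The following is a minimal pure-Python sketch that traces the tensor shapes through forward() for a given batch size. It assumes the sizes used above (sequence length 196, embedding size 768, 12 heads hard-coded in the reshape). The helpers expand_shape, reshape_shape, and trace_forward_shapes are hypothetical and mimic only the shape rules of Tensor.expand and Tensor.reshape; no real tensors are involved.

```python
from math import prod


def expand_shape(shape, target):
    # Mimics Tensor.expand for equal-rank shapes: -1 keeps the current size,
    # and a size-1 dimension may be broadcast to any size.
    out = []
    for cur, tgt in zip(shape, target):
        if tgt == -1:
            out.append(cur)
        elif cur == 1 or cur == tgt:
            out.append(tgt)
        else:
            raise ValueError(f"cannot expand dimension of size {cur} to {tgt}")
    return tuple(out)


def reshape_shape(shape, target):
    # Mimics Tensor.reshape: at most one -1, inferred from the total element count.
    total = prod(shape)
    n_infer = sum(1 for d in target if d == -1)
    if n_infer > 1:
        raise ValueError("only one dimension can be inferred")
    known = prod(d for d in target if d != -1)
    if n_infer == 0:
        if known != total:
            raise ValueError("element count mismatch")
        return tuple(target)
    if total % known != 0:
        raise ValueError("element count is not divisible")
    return tuple(total // known if d == -1 else d for d in target)


def trace_forward_shapes(batch_size, seq_len=196, embed_dim=768, num_heads=12):
    x = (batch_size, seq_len, embed_dim)
    cls = expand_shape((1, 1, embed_dim), (batch_size, -1, -1))  # cls_token.expand
    cat = (batch_size, cls[1] + x[1], embed_dim)  # torch.cat along dim=1
    qkv = (batch_size, cat[1], 3 * embed_dim)  # nn.Linear(embed_dim, 3 * embed_dim)
    return reshape_shape(qkv, (batch_size, qkv[1], 3, num_heads, -1))


print(trace_forward_shapes(4))  # (4, 197, 3, 12, 64)
```

Every batch size yields the same trailing dimensions (197, 3, 12, 64). This is why the batch dimension is the only one the compiler needs to treat as dynamic.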
Approach 1: JIT Compilation with torch.compile
The first approach uses PyTorch’s torch.compile with the TensorRT backend. This is a Just-In-Time (JIT) compilation method where the model is compiled during the first inference call.
Key points:
- Use torch._dynamo.mark_dynamic() to mark which dimensions of an input tensor are dynamic
- The index argument selects the dimension (index=0 is the batch dimension)
- The min and max arguments bound the range of sizes that dimension may take
- The compiled module then accepts any batch size within that range
x1 = x.clone()
torch._dynamo.mark_dynamic(x1, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out1 = trt_module(x1)
Approach 2: AOT Compilation with torch_tensorrt.compile
The second approach uses Ahead-Of-Time (AOT) compilation with torch_tensorrt.compile. This compiles the model upfront before inference.
Key points:
- Use torch_tensorrt.Input() to specify the dynamic shape range of each input
- Provide min_shape, opt_shape, and max_shape for each input
- TensorRT optimizes kernel selection for opt_shape, so it should reflect the most common input size
- Set ir="dynamo" to use the Dynamo frontend
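The three shapes passed to torch_tensorrt.Input describe a TensorRT-style optimization profile: each dimension has a lower bound, a preferred size, and an upper bound. The following minimal sketch shows the constraints such a profile implies. The helper names validate_profile and fits_profile are hypothetical and are not part of the Torch-TensorRT API.

```python
def validate_profile(min_shape, opt_shape, max_shape):
    # Hypothetical check: equal ranks and min <= opt <= max in every dimension.
    if not (len(min_shape) == len(opt_shape) == len(max_shape)):
        raise ValueError("min/opt/max shapes must have the same rank")
    for lo, opt, hi in zip(min_shape, opt_shape, max_shape):
        if not lo <= opt <= hi:
            raise ValueError(f"expected {lo} <= {opt} <= {hi}")


def fits_profile(shape, min_shape, max_shape):
    # A concrete input shape is covered if every dimension lies within [min, max].
    return len(shape) == len(min_shape) and all(
        lo <= d <= hi for d, lo, hi in zip(shape, min_shape, max_shape)
    )


MIN, OPT, MAX = (1, 196, 768), (4, 196, 768), (32, 196, 768)
validate_profile(MIN, OPT, MAX)
print(fits_profile((4, 196, 768), MIN, MAX))   # True
print(fits_profile((64, 196, 768), MIN, MAX))  # False: batch exceeds max
print(fits_profile((4, 197, 768), MIN, MAX))   # False: sequence length is static
```

Because min and max coincide for the last two dimensions, only the batch dimension is actually dynamic in this profile.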
x2 = x.clone()
example_input = torch_tensorrt.Input(
min_shape=[1, 196, 768],
opt_shape=[4, 196, 768],
max_shape=[32, 196, 768],
dtype=torch.float32,
)
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[example_input])
out2 = trt_module(x2)
Approach 3: AOT with torch.export + Dynamo Compile
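The third approach exports the model with PyTorch's torch.export API and then compiles the resulting ExportedProgram with Torch-TensorRT's Dynamo compiler. It gives the most explicit control over dynamic shapes, because each dynamic dimension is declared as a named symbol with its own bounds.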
Key points:
- Use torch.export.Dim() to define a named symbolic dimension with min and max bounds
- Build a dynamic_shapes dictionary that maps each forward() argument name to its dynamic dimensions
- Export the model to an ExportedProgram with these constraints
- Compile the exported program with torch_tensorrt.dynamo.compile
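In dynamic_shapes, the key "x" is the name of the forward() argument, and the inner dictionary maps dimension index 0 to the symbol bs. The following minimal sketch shows how such a specification, together with the example input's shape, translates into the same min/max bounds that Approach 2 passes to torch_tensorrt.Input. DimRange and shape_bounds are hypothetical stand-ins for illustration only; they are not the torch.export implementation.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class DimRange:
    # Hypothetical stand-in for torch.export.Dim: a named symbolic size with bounds.
    name: str
    min: int
    max: int


def shape_bounds(example_shape, dim_spec):
    # dim_spec maps dimension index -> DimRange; unlisted dimensions stay static
    # at their example size. Assumed: the example size must lie within the range.
    lo, hi = [], []
    for i, size in enumerate(example_shape):
        if i in dim_spec:
            dim = dim_spec[i]
            if not dim.min <= size <= dim.max:
                raise ValueError(f"example size {size} outside {dim.name} range")
            lo.append(dim.min)
            hi.append(dim.max)
        else:
            lo.append(size)
            hi.append(size)
    return tuple(lo), tuple(hi)


bs = DimRange("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
print(shape_bounds((4, 196, 768), dynamic_shapes["x"]))
# ((1, 196, 768), (32, 196, 768))
```

The result matches the min_shape and max_shape used in Approach 2, which is why both compilation paths accept the same range of batch sizes.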
x3 = x.clone()
bs = torch.export.Dim("bs", min=1, max=32)
dynamic_shapes = {"x": {0: bs}}
exp_program = torch.export.export(model, (x3,), dynamic_shapes=dynamic_shapes)
trt_module = torch_tensorrt.dynamo.compile(exp_program, (x3,))
out3 = trt_module(x3)
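Verify All Approaches Produce Matching Results
All three approaches should produce the same numerical results up to floating-point tolerance. torch.allclose compares elementwise within a relative and absolute tolerance rather than requiring bit-for-bit equality. This check confirms that each compilation path handles the dynamic batch dimension consistently for the example input.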
assert torch.allclose(out1, out2)
assert torch.allclose(out1, out3)
assert torch.allclose(out2, out3)
print("All three approaches produced matching results!")
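The following pure-Python sketch shows the elementwise criterion that torch.allclose documents, |a - b| <= atol + rtol * |b|, with the default rtol=1e-05 and atol=1e-08. It works on flat lists and ignores equal_nan and broadcasting. The function is a hypothetical illustration, not the PyTorch implementation.

```python
def allclose(a, b, rtol=1e-05, atol=1e-08):
    # Elementwise |a - b| <= atol + rtol * |b|; note that b acts as the reference.
    return len(a) == len(b) and all(
        abs(x - y) <= atol + rtol * abs(y) for x, y in zip(a, b)
    )


print(allclose([1.0, 2.0], [1.0, 2.0]))  # True
print(allclose([1.0], [1.0 + 5e-6]))     # True: difference is below rtol * |b|
print(allclose([1.0], [1.001]))          # False
```

TensorRT may select different kernels for each compiled module, so small differences between outputs are expected. If an assertion fails only marginally, passing explicit rtol and atol values to torch.allclose is the usual remedy.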