Sharing Dynamic Dimensions Across Inputs
When a model takes multiple inputs whose dynamic axes must be equal at runtime — for example, HuggingFace-style encoders where input_ids and attention_mask are both shaped [batch, seq_len] — naively assigning an independent dynamic dimension to each input causes torch.export to raise a ConstraintViolationError. The exporter detects that the two independent symbols are forced equal by the model’s forward pass (e.g. a broadcast) and rejects the export.
torch_tensorrt.Input(shared_dims={axis: name}) solves this: axes that share the same name across inputs are exported as a single torch.export.Dim, so the equality constraint is satisfied automatically. All dynamic-shape intent lives on the Input objects — no separate dynamic_shapes argument or torch.export knowledge is required at the call site.
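Here is a minimal pure-Python sketch of the naming rule. It assumes nothing about Torch-TensorRT internals. `SymDim` is a hypothetical stand-in for `torch.export.Dim`, and `fold_shared_dims` is a hypothetical helper that turns per-input `{axis: name}` annotations into a `dynamic_shapes`-style dictionary. The sketch shows the property the paragraph above relies on: each name yields exactly one symbol object, so every axis tagged `"B"` refers to the same symbol.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(eq=False)
class SymDim:
    """Hypothetical stand-in for torch.export.Dim.

    eq=False keeps identity-based equality, so in this sketch every
    SymDim object is its own symbol, even if two happen to share a name.
    """

    name: str


def fold_shared_dims(
    shared_dims_by_input: Dict[str, Dict[int, str]],
) -> Dict[str, Dict[int, SymDim]]:
    """Hypothetical helper: map each input's {axis: name} to {axis: SymDim}.

    Axes annotated with the same name reuse a single SymDim object.
    """
    registry: Dict[str, SymDim] = {}  # name -> the one symbol for that name
    dynamic_shapes: Dict[str, Dict[int, SymDim]] = {}
    for input_name, axes in shared_dims_by_input.items():
        dynamic_shapes[input_name] = {}
        for axis, dim_name in axes.items():
            # Create the symbol the first time a name is seen, then reuse it.
            sym = registry.setdefault(dim_name, SymDim(dim_name))
            dynamic_shapes[input_name][axis] = sym
    return dynamic_shapes


shapes = fold_shared_dims({
    "input_ids": {0: "B"},
    "attention_mask": {0: "B"},
})
# Axis 0 of both inputs refers to the very same symbol object.
print(shapes["input_ids"][0] is shapes["attention_mask"][0])  # True
```

If the two inputs used different names (say `"B1"` and `"B2"`), the helper would produce two distinct `SymDim` objects. That is the independent-symbol situation the exporter rejects for models like the one below.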
Imports and Model Definition
```python
import torch
import torch.nn as nn
import torch_tensorrt
# Accuracy helpers (used to compare compiled output with the PyTorch output)
from torch_tensorrt.dynamo.utils import COSINE_THRESHOLD, cosine_similarity
```
Define a HuggingFace-style encoder whose two inputs share the batch axis. The x * mask broadcast in forward (the embedding output multiplied by the expanded mask) forces input_ids.shape[0] == attention_mask.shape[0] on every forward call, and it ties shape[1] of the two inputs in the same way. This is exactly the pattern that triggers ConstraintViolationError when the batch axis is exported as two independent Dim objects.
```python
class SharedDimEncoder(nn.Module):
    def __init__(self, vocab: int = 1024, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab, hidden)
        self.proj = nn.Linear(hidden, hidden)

    def forward(self, input_ids: torch.Tensor, attention_mask: torch.Tensor):
        x = self.embed(input_ids)  # [B, S, hidden]
        mask = attention_mask.unsqueeze(-1).to(x.dtype)  # [B, S, 1]
        return self.proj(x * mask)  # [B, S, hidden]


model = SharedDimEncoder().cuda().eval()
```
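To see why independent symbols fail, the following pure-Python sketch imitates the shape reasoning behind `x * mask` in `forward`. Shapes are tuples of concrete ints or symbol names. `broadcast_constraints` is a hypothetical helper, not part of PyTorch or Torch-TensorRT. It aligns shapes from the right as PyTorch broadcasting does, assumes symbolic sizes are never 1, and records every pair of sizes that must be equal for the multiply to succeed.

```python
from itertools import zip_longest
from typing import List, Tuple, Union

# A size is either a concrete int or a symbolic dimension name.
Size = Union[int, str]


def broadcast_constraints(
    a: Tuple[Size, ...], b: Tuple[Size, ...]
) -> List[Tuple[Size, Size]]:
    """Hypothetical helper: size pairs that must be equal for a * b to broadcast.

    Shapes are aligned from the right; missing leading dims count as 1.
    Symbolic sizes (str) are assumed never to be 1, so they cannot be
    broadcast away and must match the size they meet.
    """
    constraints: List[Tuple[Size, Size]] = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == 1 or db == 1 or da == db:
            continue  # broadcastable with no extra requirement
        if isinstance(da, int) and isinstance(db, int):
            raise ValueError(f"incompatible sizes {da} and {db}")
        constraints.append((da, db))  # these two sizes are forced equal
    return constraints


# x = embed(input_ids) -> [B_ids, S_ids, 64]; mask -> [B_mask, S_mask, 1]
independent = broadcast_constraints(("B_ids", "S_ids", 64), ("B_mask", "S_mask", 1))
print(independent)  # [('S_ids', 'S_mask'), ('B_ids', 'B_mask')]

# Shared names: nothing left to prove.
print(broadcast_constraints(("B", "S", 64), ("B", "S", 1)))  # []
```

With one symbol per input, both the batch pair and the sequence pair come back as equality constraints. With shared names the list is empty, which is the situation `shared_dims` is meant to set up.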
Sharing multiple axes
If both the batch and sequence-length axes are dynamic and must be shared, annotate both axes on each input:

```python
shared_dims={0: "B", 1: "S"}
```

The same name on the same axis across different inputs produces one shared Dim; different names on different axes produce independent Dims.
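The grouping rule can be checked with a small pure-Python sketch. `group_shared_axes` is a hypothetical helper, not a Torch-TensorRT function. It lists, for each dimension name, every `(input, axis)` pair annotated with it, so each key stands for one exported dimension.

```python
from collections import defaultdict
from typing import Dict, List, Tuple


def group_shared_axes(
    shared_dims_by_input: Dict[str, Dict[int, str]],
) -> Dict[str, List[Tuple[str, int]]]:
    """Hypothetical helper: group (input, axis) pairs by dimension name.

    Every pair listed under one key is tied to that single dimension.
    """
    groups: Dict[str, List[Tuple[str, int]]] = defaultdict(list)
    for input_name, axes in shared_dims_by_input.items():
        for axis, dim_name in sorted(axes.items()):
            groups[dim_name].append((input_name, axis))
    return dict(groups)


groups = group_shared_axes({
    "input_ids": {0: "B", 1: "S"},
    "attention_mask": {0: "B", 1: "S"},
})
print(groups)
# {'B': [('input_ids', 0), ('attention_mask', 0)], 'S': [('input_ids', 1), ('attention_mask', 1)]}

# Distinct sequence names give the sequence axes independent dimensions.
independent = group_shared_axes({
    "input_ids": {0: "B", 1: "S_ids"},
    "attention_mask": {0: "B", 1: "S_mask"},
})
print(len(groups), len(independent))  # 2 3
```

For SharedDimEncoder, the second layout would be expected to hit the same ConstraintViolationError on the sequence axis, because x * mask ties that axis across the two inputs as well.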
```python
print("\nAll checks passed.")
```