Dynamic Shapes#
Created On: May 19, 2023 | Last Updated On: Sep 22, 2025
This section explains how to work with dynamic shapes in PyTorch, including how to debug and fix common errors, implement support for dynamic shapes in operators, and understand the underlying mechanisms.
Dynamic shapes allow PyTorch models to handle inputs with varying dimensions without recompilation. This enables more flexible models that can process different batch sizes, sequence lengths, or image dimensions in a single compiled artifact. Dynamic shapes work by symbolically tracing tensor dimensions rather than using concrete values, creating a computation graph that adapts to different input shapes at runtime. By default, PyTorch assumes all input shapes to be static.
Typically, deep learning compilers only support static shapes, requiring recompilation for input shape changes. While this approach covers many use cases, there are situations where this is insufficient:
Variable Dimensions - Batch sizes or sequence lengths vary, such as in adaptive batching.
Data-Dependent Outputs - Models produce outputs based on input data, like variable bounding boxes in detection models.
Sparse Representations - Processing depends on data-varying sparse structures, such as in sparse tensors, jagged tensors, and graph neural networks.
Dynamic shapes do not support dynamic rank programs, that is, programs whose input tensors change in dimensionality, because this is uncommon and adds unnecessary complexity.
What does it mean for a size/integer to be dynamic?#
Dynamic shapes avoid recompilations by making certain dimensions or integers dynamic. For example, if a function f(x) is compiled with static sizes, it must be recompiled for every new input size:
Note
For simplicity, this example uses @torch.compile(dynamic=True). Note that this option is not recommended because it is error prone. For the recommended way of enabling dynamic shapes, see Enabling Dynamic Behavior.
import torch

@torch.compile(dynamic=False)
def f(x):
    return x* x.size()[0]

f(torch.rand(10))
f(torch.rand(20))
f(torch.rand(30))
f(torch.rand(40))
TRACED GRAPH
===== __compiled_fn_1_c06521f2_4028_4ccf_873b_e3051963b340 =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/281359623.py:4 in f, code: return x* x.size()[0]
mul: "f32[10][1]cpu" = l_x_ * 10; l_x_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_3_4edc764c_89ae_4fbc_b57d_c7b4841ec1bb =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[20][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/281359623.py:4 in f, code: return x* x.size()[0]
mul: "f32[20][1]cpu" = l_x_ * 20; l_x_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_5_008248ce_24c3_429e_a33b_794e0539b0a8 =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[30][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/281359623.py:4 in f, code: return x* x.size()[0]
mul: "f32[30][1]cpu" = l_x_ * 30; l_x_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_7_d4129989_be44_46bf_8592_f21e104ec88e =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[40][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/281359623.py:4 in f, code: return x* x.size()[0]
mul: "f32[40][1]cpu" = l_x_ * 40; l_x_ = None
return (mul,)
tensor([35.4583, 10.8351, 10.3078, 11.7326, 7.5450, 27.5295, 24.7763, 17.8126,
20.1385, 32.1584, 32.5022, 35.5030, 10.2220, 12.0239, 1.6540, 35.6477,
21.7530, 24.9889, 37.6757, 13.7359, 25.1206, 4.5433, 3.3088, 18.9730,
22.2610, 2.7201, 29.4601, 3.0944, 13.2264, 11.5001, 32.7509, 34.1086,
9.3247, 24.2391, 10.9543, 35.6147, 3.0197, 39.4121, 36.8708, 39.1429])
In the produced output, you can see that four graphs were generated. See the corresponding tlparse output.
By making the size dynamic, the function can handle various sizes without recompilation:
import torch

@torch.compile(dynamic=True)
def f(x):
    return x* x.size()[0]

f(torch.rand(10))
f(torch.rand(20))
f(torch.rand(30))
f(torch.rand(40))
TRACED GRAPH
===== __compiled_fn_9_bbe5b4ed_aa6c_4f49_99da_5aeb285f720a =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/1046103881.py:4 in f, code: return x* x.size()[0]
mul: "f32[s77][1]cpu" = l_x_ * s77; l_x_ = s77 = None
return (mul,)
tensor([ 2.5502, 16.7698, 11.3313, 18.3094, 12.6983, 14.2855, 26.8563, 26.4146,
1.7461, 28.5653, 11.6814, 35.6030, 35.6810, 39.4119, 34.1359, 11.3376,
20.3073, 12.0899, 33.4490, 6.4256, 39.8614, 9.6054, 2.9614, 14.6022,
3.3006, 15.5555, 3.7537, 0.9897, 39.1331, 0.2470, 9.6593, 17.9025,
6.1057, 5.7430, 21.1026, 7.9069, 20.6801, 25.5482, 11.9946, 29.1124])
With dynamic shapes enabled, only one graph is created. See the corresponding tlparse output.
While compilation time differences are minimal for this small example, more complex use cases would show significant performance improvements.
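If you want to confirm how often recompilation happens in your own code, one option is to enable recompile logging. The following is a minimal sketch, assuming the sizes are arbitrary; setting TORCH_LOGS="recompiles" on the command line works as well:
import torch

# Log a message, including the failed guard, every time torch.compile recompiles.
torch._logging.set_logs(recompiles=True)

@torch.compile(dynamic=False)
def f(x):
    return x * x.size()[0]

f(torch.rand(10))  # first compilation
f(torch.rand(20))  # triggers a recompile: the size-10 guard no longer holds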
What is a specialization?#
Specialization refers to optimizing a computational graph for specific input shapes by examining shape conditions during control flow. If a branch is taken based on a shape condition, the graph is tailored for that condition. If a new input doesn’t meet this condition, the system will recompile the graph.
Specialization allows you to create optimized computational graphs for specific input shapes, which can significantly improve execution speed.
import torch

@torch.compile(dynamic=True)
def f(x):
    if x.size()[0] == 10:
        return x * 10
    if x.size()[0] <= 30:
        return x*200
    return x*x.size()[0]

f(torch.rand(10))
f(torch.rand(20))
f(torch.rand(30))
f(torch.rand(40))
f(torch.rand(50))
TRACED GRAPH
===== __compiled_fn_11_36bbe289_b3db_4537_aaef_b7e274fe17cd =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/953537014.py:5 in f, code: return x * 10
mul: "f32[10][1]cpu" = l_x_ * 10; l_x_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_13_c7ea3972_1019_4e07_9c66_0df07eeb74f8 =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/953537014.py:8 in f, code: return x*200
mul: "f32[s77][1]cpu" = l_x_ * 200; l_x_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_15_939f2a5c_8f83_4ce4_946d_d870bbd5989c =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/953537014.py:10 in f, code: return x*x.size()[0]
mul: "f32[s77][1]cpu" = l_x_ * s77; l_x_ = s77 = None
return (mul,)
tensor([24.8375, 29.1014, 15.5320, 5.1851, 39.6424, 35.1865, 40.0169, 21.9117,
17.3045, 20.5685, 43.7065, 24.2304, 37.4335, 1.0661, 2.3260, 33.3871,
2.4180, 39.5253, 32.0404, 17.3826, 47.8727, 35.5337, 14.6177, 0.5518,
47.4266, 37.4349, 15.4111, 26.3020, 44.1022, 0.6630, 21.1678, 0.2953,
21.0656, 46.8877, 47.5224, 15.3420, 46.3998, 20.4242, 21.1397, 23.2472,
7.7806, 5.9444, 47.4956, 44.6659, 45.4053, 12.7762, 23.2685, 8.2041,
1.0405, 4.9260])
In the code above, the first graph specializes on an input size of 10, in which case it returns x * 10. If the input size is at most 30 (but not 10), the function returns x * 200; otherwise it returns x * x.size()[0]. In the output, you can see that this creates three graphs. See the corresponding tlparse output.
The following figure shows how graphs are created for the above function.
Enabling Dynamic Behavior#
There are the following ways to make things dynamic:
User Annotations (preferred)
torch.compile(dynamic=True) (not recommended; for testing only)
Advanced Options to Control Dynamic Behavior (for advanced use cases)
Read about each of these options below.
Automatic dynamic#
Automatic dynamic is the default behavior, where torch.compile() performs the initial compilation assuming static shapes while recording the input sizes from that first compilation. When a recompile is triggered, it uses this information to identify which dimensions have changed and marks those as dynamic for the second compilation.
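For illustration, here is a minimal sketch (the sizes are arbitrary) that relies only on automatic dynamic:
import torch

@torch.compile()  # no dynamic flag and no annotations
def f(x):
    return x * x.size()[0]

f(torch.rand(10))  # first compilation assumes a static size of 10
f(torch.rand(20))  # size changed: recompiles with dimension 0 marked dynamic
f(torch.rand(30))  # reuses the dynamic graph, no further recompilation
f(torch.rand(40))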
User Annotations#
Several APIs allow users to explicitly mark specific inputs by name or code as dynamic. This is useful for avoiding initial compilations that would eventually become dynamic with the previous tools. It is also used to mark elements that do not automatically get marked as dynamic, such as neural network module parameters, and so on. User annotations are the preferred way to enable dynamic shapes.
mark_dynamic(tensor, dim, min=min, max=max)#
The torch._dynamo.mark_dynamic() function marks a tensor dimension as dynamic and will fail if it gets specialized. It does not work for integers. Use this function only if you know that all graphs in the frame using this input converge to a single dynamic graph. Otherwise, you may encounter a misleading constraint violation error. In such cases, consider using torch._dynamo.maybe_mark_dynamic(). Currently, torch._dynamo.mark_dynamic() does not take precedence over force_parameter_static_shapes = True or force_nn_module_property_static_shapes = True.
If you know in advance that a particular dimension will be dynamic, you can avoid the initial recompilation by using torch._dynamo.mark_dynamic(tensor, dim). Additionally, if you already know the minimum and maximum possible values for this dimension, you can specify them with torch._dynamo.mark_dynamic(tensor, dim, min=min, max=max).
Here is a quick example:
import torch

@torch.compile(dynamic=True)
def f(x):
    return x * x.size()[0]

x = torch.randn(10)
torch._dynamo.mark_dynamic(x, 0)
# the first invocation passes a tensor marked as dynamic
f(x)
# the rest of these invocations use the dynamically compiled code
f(torch.randn(20))
f(torch.randn(30))
f(torch.randn(40))
TRACED GRAPH
===== __compiled_fn_17_6ff81c2d_5b7c_4799_b885_d74af8b5a10c =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/3760279969.py:5 in f, code: return x * x.size()[0]
mul: "f32[s77][1]cpu" = l_x_ * s77; l_x_ = s77 = None
return (mul,)
tensor([ 0.3497, 28.3275, 16.2943, -87.0090, 45.5919, 38.0078, -47.3709,
67.1037, -61.7345, -35.6715, -67.5932, 48.3146, -3.0090, -40.8550,
10.2011, 15.8710, 36.5330, 24.4031, 52.1428, 27.1535, 3.3862,
18.8536, 17.7592, 39.5597, 72.5149, -26.2967, 66.1949, -18.0855,
3.9902, 8.1664, -72.0031, 1.4079, 57.0317, 84.5466, -70.5541,
14.3071, 20.9651, 2.5737, -70.8189, -16.3221])
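If the valid range of the dimension is known up front, the same call can include bounds. This is a minimal sketch; the bounds 2 and 64 are illustrative:
import torch

@torch.compile()
def f(x):
    return x * x.size()[0]

x = torch.randn(10)
# Dimension 0 is dynamic and assumed to stay within [2, 64].
torch._dynamo.mark_dynamic(x, 0, min=2, max=64)
f(x)
f(torch.randn(32))  # served by the same dynamic graph while the size stays in range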
maybe_mark_dynamic(tensor, dim)#
The torch._dynamo.maybe_mark_dynamic() function shares all properties with torch._dynamo.mark_dynamic() but does not fail if the size gets specialized. Use it for inputs shared by multiple graphs or if the number of graphs does not converge to one for a specific frame. For instance, in the example above, use torch._dynamo.maybe_mark_dynamic() because graphs with sizes 0 and 1 will specialize. However, you can use torch._dynamo.mark_dynamic() to ensure you never specialize.
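Here is a minimal sketch of the same function using torch._dynamo.maybe_mark_dynamic(); unlike torch._dynamo.mark_dynamic(), it will not raise an error if a guard later specializes the size:
import torch

@torch.compile()
def f(x):
    return x * x.size()[0]

x = torch.randn(10)
# Prefer a dynamic dimension 0, but tolerate specialization if it happens.
torch._dynamo.maybe_mark_dynamic(x, 0)
f(x)
f(torch.randn(20))  # served by the dynamic graph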
mark_unbacked(tensor, dim)#
The torch._dynamo.mark_unbacked() function marks a tensor dimension as unbacked. It is unlikely to be the tool you need, but it could be useful if the specialization occurs inside a guard_size_oblivious(x) condition and using it removes the specialization. Ensure that it fixes the specialization and does not introduce a data-dependent error that converts to a graph break at or before the specialization location you are trying to avoid. It might be better to use the next option.
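Here is a minimal sketch of marking a dimension as unbacked; the function body is illustrative:
import torch

@torch.compile()
def f(x):
    return x * 2

x = torch.randn(10)
# Treat dimension 0 as unbacked: the compiler makes no assumptions about its value,
# including whether it is 0 or 1, so size-oblivious checks will not specialize on it.
torch._dynamo.mark_unbacked(x, 0)
f(x)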
Dynamic Allow List (DYNAMIC_SOURCES)#
Use the environment variable TORCH_COMPILE_DYNAMIC_SOURCES to pass a comma-separated list of source names to be marked as dynamic. For example:
TORCH_COMPILE_DYNAMIC_SOURCES="L['x'],L['y']"
It's easiest to find these dynamic source names using the PGO artifact in tlparse. You can copy and paste the dynamic source names from the PGO artifact. This method works for integers and tensor sizes and takes precedence over all other flags that force static shapes. It will not throw an error if what is marked dynamic gets specialized or if the provided input does not exist.
Here is an example:
import torch

@torch.compile()
def f(x):
    return x * x.size()[0]

with torch.compiler.config.patch(dynamic_sources="L['x']"):
    f(torch.rand(10))
    f(torch.rand(20))
    f(torch.rand(30))
    f(torch.rand(40))
TRACED GRAPH
===== __compiled_fn_19_661f590e_dda3_467e_b689_40839f64b2ec =====
/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, s77: "Sym(s77)", L_x_: "f32[s77][1]cpu"):
l_x_ = L_x_
# File: /tmp/ipykernel_689/2867773694.py:5 in f, code: return x * x.size()[0]
mul: "f32[s77][1]cpu" = l_x_ * s77; l_x_ = s77 = None
return (mul,)
tensor([38.9773, 15.7884, 3.5391, 24.3714, 3.4340, 21.3504, 4.0010, 14.3519,
23.5550, 1.5677, 28.2817, 22.6811, 6.6651, 4.8821, 0.8067, 36.5888,
16.2779, 11.0907, 31.3330, 32.9416, 7.4523, 26.8163, 12.6504, 26.5602,
30.5523, 18.8186, 23.0794, 21.4367, 33.1079, 6.2789, 32.0718, 23.5863,
33.8042, 23.3560, 39.3990, 2.5001, 20.1213, 9.1935, 39.4230, 39.1510])
torch.compiler.set_stance("eager_then_compile")#
At times, identifying the appropriate inputs to mark as dynamic can be challenging. If you are willing to accept a performance cost for the first batch, another convenient option is to use the eager_then_compile stance, which automatically determines the dynamic inputs for you. For more information, see torch.compiler.set_stance() and Dynamic Compilation Control with torch.compiler.set_stance.
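Here is a minimal sketch (the sizes are arbitrary) using the eager_then_compile stance: the first call runs in eager mode while input shapes are recorded, and subsequent calls are compiled with the dynamism inferred from those recordings:
import torch

@torch.compile()
def f(x):
    return x * x.size()[0]

torch.compiler.set_stance("eager_then_compile")
f(torch.rand(10))  # runs eagerly; input shapes are recorded
f(torch.rand(20))  # compiled; the dimension that varied is treated as dynamic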
torch.compile(dynamic=True) (Not recommended)#
This setting forces all sizes and integers to be dynamic, which increases the chance of encountering dynamic shape bugs. It is not recommended because it is error prone: making every input size dynamic may cause performance regressions and ultimately increase compilation time.
PyTorch also provides advanced control options for dynamic shapes, see: Advanced Options to Control Dynamic Behavior.
Where Do I Go From Here?#
If you encounter a framework code bug or an issue with specialization,
file an issue so it can be reviewed and potentially improved. If the issue
is within your user code, consider whether you are willing to rewrite your
code to avoid it. Determine if it affects correctness or if it’s a redundant
check. If the issue involves a Triton custom kernel with a constexpr
argument, evaluate whether you can rewrite it to address the problem.