.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "intermediate/torch_export_tutorial.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_intermediate_torch_export_tutorial.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_intermediate_torch_export_tutorial.py:

torch.export Tutorial
===================================================
**Author:** William Wen, Zhengxu Chen, Angela Yi, Pian Pawakapan

.. GENERATED FROM PYTHON SOURCE LINES 10-29

.. warning::

    ``torch.export`` and its related features are in prototype status and are subject to
    backwards-compatibility-breaking changes. This tutorial provides a snapshot of
    ``torch.export`` usage as of PyTorch 2.5.

:func:`torch.export` is the PyTorch 2.X way to export PyTorch models into
standardized model representations, intended to be run on different
(i.e. Python-less) environments. The official documentation can be found
`here <https://pytorch.org/docs/main/export.html>`__.

In this tutorial, you will learn how to use :func:`torch.export` to extract
``ExportedProgram`` objects (i.e. single-graph representations) from PyTorch programs.
We also detail some considerations and modifications that you may need to make
in order to make your model compatible with ``torch.export``.

**Contents**

.. contents::
    :local:

.. GENERATED FROM PYTHON SOURCE LINES 32-61

Basic Usage
-----------

``torch.export`` extracts single-graph representations from PyTorch programs
by tracing the target function, given example inputs.
``torch.export.export()`` is the main entry point for ``torch.export``.

In this tutorial, ``torch.export`` and ``torch.export.export()`` are practically synonymous,
though ``torch.export`` generally refers to the PyTorch 2.X export process, and
``torch.export.export()`` generally refers to the actual function call.

The signature of ``torch.export.export()`` is:

.. code-block:: python

    export(
        mod: torch.nn.Module,
        args: Tuple[Any, ...],
        kwargs: Optional[Dict[str, Any]] = None,
        *,
        dynamic_shapes: Optional[Dict[str, Dict[int, Dim]]] = None
    ) -> ExportedProgram

``torch.export.export()`` traces the tensor computation graph from calling ``mod(*args, **kwargs)``
and wraps it in an ``ExportedProgram``, which can be serialized or executed later with
different inputs. To execute the ``ExportedProgram``, we can call ``.module()``
on it to return a ``torch.nn.Module`` that is callable, just like the original program.
We will detail the ``dynamic_shapes`` argument later in the tutorial.

.. GENERATED FROM PYTHON SOURCE LINES 61-79

.. code-block:: Python

    import torch
    from torch.export import export

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = torch.nn.Linear(100, 10)

        def forward(self, x, y):
            return torch.nn.functional.relu(self.lin(x + y), inplace=True)

    mod = MyModule()
    exported_mod = export(mod, (torch.randn(8, 100), torch.randn(8, 100)))
    print(type(exported_mod))
    print(exported_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    <class 'torch.export.exported_program.ExportedProgram'>
    tensor([[0.0000, 0.0000, 0.0753, 0.4761, 0.0000, 0.0000, 0.0000, 0.0816, 0.0692, 0.0000],
            [0.0000, 1.1218, 1.2616, 0.0000, 0.0000, 0.0000, 0.8918, 0.0000, 0.0000, 0.0000],
            [1.2681, 0.0000, 1.3830, 0.0000, 0.7240, 0.0804, 0.0000, 0.0000, 0.3206, 0.0000],
            [0.4810, 0.0000, 0.0000, 0.4256, 1.5272, 0.0000, 0.8898, 0.0000, 1.1272, 0.3875],
            [0.0000, 0.0000, 0.0000, 0.6868, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
            [0.0000, 0.0000, 0.1212, 1.3065, 0.0000, 1.0019, 0.3537, 0.0000, 0.0000, 0.0036],
            [0.2707, 0.0000, 0.6322, 0.0000, 0.0000, 0.1343, 0.2890, 0.0000, 0.0000, 0.0000],
            [0.4347, 0.0000, 0.0000, 0.0000, 0.0000, 0.2712, 0.0000, 0.0000, 0.7067, 0.1505]],
           grad_fn=<ReluBackward0>)
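Because the ``ExportedProgram`` is a standardized, self-contained representation, it can
also be saved to disk and loaded back later, for example in a different process. Below is a
minimal sketch using the (also prototype) ``torch.export.save`` and ``torch.export.load``
utilities; the ``.pt2`` filename is just an example:

.. code-block:: Python

    # Serialize the ExportedProgram to a .pt2 archive on disk...
    torch.export.save(exported_mod, "exported_mod.pt2")

    # ...and load it back later, potentially in a separate process.
    loaded_mod = torch.export.load("exported_mod.pt2")
    print(loaded_mod.module()(torch.randn(8, 100), torch.randn(8, 100)))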
.. GENERATED FROM PYTHON SOURCE LINES 80-91

Let's review some attributes of ``ExportedProgram`` that are of interest.

The ``graph`` attribute is an `FX graph <https://pytorch.org/docs/stable/fx.html#torch.fx.Graph>`__
traced from the function we exported, that is, the computation graph of all PyTorch operations.
The FX graph is in "ATen IR", meaning that it contains only "ATen-level" operations.

The ``graph_signature`` attribute gives a more detailed description of the input and output nodes
in the exported graph, describing which ones are parameters, buffers, user inputs, or user outputs.

The ``range_constraints`` attribute will be covered later.

.. GENERATED FROM PYTHON SOURCE LINES 91-94

.. code-block:: Python

    print(exported_mod)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_lin_weight: "f32[10, 100]", p_lin_bias: "f32[10]", x: "f32[8, 100]", y: "f32[8, 100]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:71 in forward, code: return torch.nn.functional.relu(self.lin(x + y), inplace=True)
                add: "f32[8, 100]" = torch.ops.aten.add.Tensor(x, y); x = y = None
                linear: "f32[8, 10]" = torch.ops.aten.linear.default(add, p_lin_weight, p_lin_bias); add = p_lin_weight = p_lin_bias = None
                relu_: "f32[8, 10]" = torch.ops.aten.relu_.default(linear); linear = None
                return (relu_,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_weight'), target='lin.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_lin_bias'), target='lin.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='relu_'), target=None)])
    Range constraints: {}

.. GENERATED FROM PYTHON SOURCE LINES 95-97

See the ``torch.export`` `documentation <https://pytorch.org/docs/main/export.html>`__ for more details.

.. GENERATED FROM PYTHON SOURCE LINES 99-113

Graph Breaks
------------

Although ``torch.export`` shares components with ``torch.compile``, the key limitation
of ``torch.export``, especially when compared to ``torch.compile``, is that it does not
support graph breaks. This is because handling graph breaks involves interpreting
the unsupported operation with default Python evaluation, which is incompatible
with the export use case.

Therefore, in order to make your model code compatible with ``torch.export``,
you will need to modify your code to remove graph breaks.

A graph break is necessary in cases such as:

- data-dependent control flow

.. GENERATED FROM PYTHON SOURCE LINES 113-126

..
code-block:: Python class Bad1(torch.nn.Module): def forward(self, x): if x.sum() > 0: return torch.sin(x) return torch.cos(x) import traceback as tb try: export(Bad1(), (torch.randn(3, 3),)) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none class GraphModule(torch.nn.Module): def forward(self, L_x_: "f32[3, 3][3, 1]cpu"): l_x_ = L_x_ # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:116 in forward, code: if x.sum() > 0: sum_1: "f32[][]cpu" = l_x_.sum(); l_x_ = None gt: "b8[][]cpu" = sum_1 > 0; sum_1 = gt = None Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 122, in export(Bad1(), (torch.randn(3, 3),)) File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export return _export( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export ep = _export_for_training( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir gm_torch_level = _export_to_torch_ir( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner result_traced = opt_f(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 659, in _fn raise e.with_traceback(None) from None torch._dynamo.exc.Unsupported: Data-dependent branching Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow. Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround. Hint: Use `torch.cond` to express dynamic control flow. Developer debug context: attempted to jump with TensorVariable() from user code: File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 116, in forward if x.sum() > 0: Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" .. GENERATED FROM PYTHON SOURCE LINES 127-128 - accessing tensor data with ``.data`` .. 
GENERATED FROM PYTHON SOURCE LINES 128-139

.. code-block:: Python

    class Bad2(torch.nn.Module):
        def forward(self, x):
            x.data[0, 0] = 3
            return x

    try:
        export(Bad2(), (torch.randn(3, 3),))
    except Exception:
        tb.print_exc()

.. GENERATED FROM PYTHON SOURCE LINES 140-141

- calling unsupported functions (such as many built-in functions)

.. GENERATED FROM PYTHON SOURCE LINES 141-153

.. code-block:: Python

    class Bad3(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            return x + id(x)

    try:
        export(Bad3(), (torch.randn(3, 3),))
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Traceback (most recent call last):
      File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 148, in <module>
        export(Bad3(), (torch.randn(3, 3),))
      File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export
        return _export(
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper
        raise e
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper
        ep = fn(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper
        return fn(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export
        ep = _export_for_training(
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper
        raise e
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper
        ep = fn(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper
        return fn(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training
        export_artifact = export_func(  # type: ignore[operator]
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir
        gm_torch_level = _export_to_torch_ir(
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir
        gm_torch_level, _ = torch._dynamo.export(
      File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner
        result_traced = opt_f(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl
        return forward_call(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 659, in _fn
        raise e.with_traceback(None) from None
    torch._dynamo.exc.Unsupported: call_id not supported for sourceless TensorVariable

    from user code:
       File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 145, in forward
        return x + id(x)

    Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
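When it isn't obvious which lines of code cause graph breaks, ``torch._dynamo.explain``
can help locate them before you start rewriting. Below is a minimal sketch of that workflow;
``explain`` is a TorchDynamo debugging utility separate from ``torch.export``, and its output
format may change between releases:

.. code-block:: Python

    # torch._dynamo.explain runs TorchDynamo over the module and reports
    # every graph break it encounters, along with the reason for each one.
    explanation = torch._dynamo.explain(Bad1())(torch.randn(3, 3))
    print(explanation.graph_break_count)
    print(explanation.break_reasons)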
.. GENERATED FROM PYTHON SOURCE LINES 154-170

Non-Strict Export
-----------------

To trace the program, ``torch.export`` uses TorchDynamo by default, a bytecode analysis engine,
to symbolically analyze the Python code and build a graph based on the results. This analysis
allows ``torch.export`` to provide stronger guarantees about safety, but not all Python code
is supported, causing these graph breaks.

To address this issue, in PyTorch 2.3, we introduced a new mode of exporting called non-strict mode,
which traces through the program using the Python interpreter, executing it exactly as it would
in eager mode; this allows us to skip over unsupported Python features. This is done by adding
the ``strict=False`` flag.

Looking at some of the previous examples which resulted in graph breaks:

.. GENERATED FROM PYTHON SOURCE LINES 172-176

- Calling unsupported functions (such as many built-in functions) traces through, but in this
  case, ``id(x)`` gets specialized as a constant integer in the graph. This is because ``id(x)``
  is not a tensor operation, so the operation is not recorded in the graph.

.. GENERATED FROM PYTHON SOURCE LINES 176-187

.. code-block:: Python

    class Bad3(torch.nn.Module):
        def forward(self, x):
            x = x + 1
            return x + id(x)

    bad3_nonstrict = export(Bad3(), (torch.randn(3, 3),), strict=False)
    print(bad3_nonstrict)
    print(bad3_nonstrict.module()(torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:179 in forward, code: x = x + 1
                add: "f32[3, 3]" = torch.ops.aten.add.Tensor(x, 1); x = None

                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:180 in forward, code: return x + id(x)
                add_1: "f32[3, 3]" = torch.ops.aten.add.Tensor(add, 140309147538416); add = None
                return (add_1,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
    Range constraints: {}

    tensor([[1.4031e+14, 1.4031e+14, 1.4031e+14],
            [1.4031e+14, 1.4031e+14, 1.4031e+14],
            [1.4031e+14, 1.4031e+14, 1.4031e+14]])

.. GENERATED FROM PYTHON SOURCE LINES 188-190

However, there are still some features that require rewrites to the original module:

.. GENERATED FROM PYTHON SOURCE LINES 192-198

Control Flow Ops
----------------

``torch.export`` actually does support data-dependent control flow, but it needs to be
expressed using control flow ops. For example, we can fix the control flow example above
using the ``cond`` op, like so:

.. GENERATED FROM PYTHON SOURCE LINES 198-212

.. code-block:: Python

    class Bad1Fixed(torch.nn.Module):
        def forward(self, x):
            def true_fn(x):
                return torch.sin(x)
            def false_fn(x):
                return torch.cos(x)
            return torch.cond(x.sum() > 0, true_fn, false_fn, [x])

    exported_bad1_fixed = export(Bad1Fixed(), (torch.randn(3, 3),))
    print(exported_bad1_fixed)
    print(exported_bad1_fixed.module()(torch.ones(3, 3)))
    print(exported_bad1_fixed.module()(-torch.ones(3, 3)))

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:205 in forward, code: return torch.cond(x.sum() > 0, true_fn, false_fn, [x])
                sum_1: "f32[]" = torch.ops.aten.sum.default(x)
                gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 0); sum_1 = None

                 # File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:137 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
                true_graph_0 = self.true_graph_0
                false_graph_0 = self.false_graph_0
                cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x]); gt = true_graph_0 = false_graph_0 = x = None
                getitem: "f32[3, 3]" = cond[0]; cond = None
                return (getitem,)

            class true_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 3]"):
                     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:202 in true_fn, code: return torch.sin(x)
                    sin: "f32[3, 3]" = torch.ops.aten.sin.default(x); x = None
                    return (sin,)

            class false_graph_0(torch.nn.Module):
                def forward(self, x: "f32[3, 3]"):
                     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:204 in false_fn, code: return torch.cos(x)
                    cos: "f32[3, 3]" = torch.ops.aten.cos.default(x); x = None
                    return (cos,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}

    tensor([[0.8415, 0.8415, 0.8415],
            [0.8415, 0.8415, 0.8415],
            [0.8415, 0.8415, 0.8415]])
    tensor([[0.5403, 0.5403, 0.5403],
            [0.5403, 0.5403, 0.5403],
            [0.5403, 0.5403, 0.5403]])

.. GENERATED FROM PYTHON SOURCE LINES 213-224

There are limitations to ``cond`` that one should be aware of:

- The predicate (i.e. ``x.sum() > 0``) must result in a boolean or a single-element tensor.
- The operands (i.e. ``[x]``) must be tensors.
- The branch function (i.e. ``true_fn`` and ``false_fn``) signature must match with the
  operands, and both branches must return a single tensor with the same metadata
  (for example, ``dtype``, ``shape``, etc.).
- Branch functions cannot mutate input or global variables.
- Branch functions cannot access closure variables, except for ``self`` if the function is
  defined in the scope of a method; a workaround is sketched below.

For more details about ``cond``, check out the `cond documentation <https://pytorch.org/docs/main/cond.html>`__.
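For the closure restriction, the usual fix is to pass the needed tensor through the operands
list instead of capturing it. A minimal sketch follows; the module and names here are
illustrative, not from the original tutorial:

.. code-block:: Python

    class CondWithBias(torch.nn.Module):
        def forward(self, x, bias):
            # bias is threaded through the operands list rather than being
            # captured as a closure variable inside the branch functions.
            def true_fn(x, bias):
                return torch.sin(x) + bias
            def false_fn(x, bias):
                return torch.cos(x) + bias
            return torch.cond(x.sum() > 0, true_fn, false_fn, [x, bias])

    exported_cond_bias = export(CondWithBias(), (torch.randn(3, 3), torch.randn(3, 3)))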
.. GENERATED FROM PYTHON SOURCE LINES 226-228

We can also use ``map``, which applies a function across the first dimension
of the first tensor argument.

.. GENERATED FROM PYTHON SOURCE LINES 228-243

.. code-block:: Python

    from torch._higher_order_ops.map import map as torch_map

    class MapModule(torch.nn.Module):
        def forward(self, xs, y, z):
            def body(x, y, z):
                return x + y + z
            return torch_map(body, xs, y, z)

    inps = (torch.ones(6, 4), torch.tensor(5), torch.tensor(4))
    exported_map_example = export(MapModule(), inps)
    print(exported_map_example)
    print(exported_map_example.module()(*inps))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, xs: "f32[6, 4]", y: "i64[]", z: "i64[]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:236 in forward, code: return torch_map(body, xs, y, z)
                body_graph_0 = self.body_graph_0
                map_impl = torch.ops.higher_order.map_impl(body_graph_0, [xs], [y, z]); body_graph_0 = xs = y = z = None
                getitem: "f32[6, 4]" = map_impl[0]; map_impl = None
                return (getitem,)

            class body_graph_0(torch.nn.Module):
                def forward(self, xs: "f32[4]", y: "i64[]", z: "i64[]"):
                     # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:234 in body, code: return x + y + z
                    add: "f32[4]" = torch.ops.aten.add.Tensor(xs, y); xs = y = None
                    add_1: "f32[4]" = torch.ops.aten.add.Tensor(add, z); add = z = None
                    return (add_1,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='xs'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='getitem'), target=None)])
    Range constraints: {}

    tensor([[10., 10., 10., 10.],
            [10., 10., 10., 10.],
            [10., 10., 10., 10.],
            [10., 10., 10., 10.],
            [10., 10., 10., 10.],
            [10., 10., 10., 10.]])

.. GENERATED FROM PYTHON SOURCE LINES 244-247

Other control flow ops include ``while_loop``, ``associative_scan``, and ``scan``.
For more documentation on each operator, please refer to `this page `__.

.. GENERATED FROM PYTHON SOURCE LINES 249-255

Constraints/Dynamic Shapes
--------------------------

This section covers the dynamic behavior and representation of exported programs. Dynamic
behavior is specific to the particular model being exported, so for most of this tutorial,
we'll focus on this particular toy model (with the resulting tensor shapes annotated):

.. GENERATED FROM PYTHON SOURCE LINES 255-274

.. code-block:: Python

    class DynamicModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(5, 3)

        def forward(
            self,
            w: torch.Tensor,  # [6, 5]
            x: torch.Tensor,  # [4]
            y: torch.Tensor,  # [8, 4]
            z: torch.Tensor,  # [32]
        ):
            x0 = x + y  # [8, 4]
            x1 = self.l(w)  # [6, 3]
            x2 = x0.flatten()  # [32]
            x3 = x2 + z  # [32]
            return x1, x3

.. GENERATED FROM PYTHON SOURCE LINES 275-277

By default, ``torch.export`` produces a static program. One consequence of this is that at runtime,
the program won't work on inputs with different shapes, even if they're valid in eager mode.

.. GENERATED FROM PYTHON SOURCE LINES 277-290

.. code-block:: Python

    w = torch.randn(6, 5)
    x = torch.randn(4)
    y = torch.randn(8, 4)
    z = torch.randn(32)
    model = DynamicModel()
    ep = export(model, (w, x, y, z))
    model(w, x, torch.randn(3, 4), torch.randn(12))
    try:
        ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    Traceback (most recent call last):
      File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 286, in <module>
        ep.module()(w, x, torch.randn(3, 4), torch.randn(12))
      File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped
        return self._wrapped_call(self, *args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 406, in __call__
        raise e
      File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 393, in __call__
        return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl
        return inner()
      File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in inner
        args_kwargs_result = hook(self, args, kwargs)  # type: ignore[misc]
      File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn
        return fn(*args, **kwargs)
      File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 55, in _check_input_constraints_pre_hook
        _check_input_constraints_for_graph(
      File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 398, in _check_input_constraints_for_graph
        raise RuntimeError(
    RuntimeError: Expected input at *args[2].shape[0] to be equal to 8, but got 3

.. GENERATED FROM PYTHON SOURCE LINES 291-297

Basic concepts: symbols and guards
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

To enable dynamism, ``export()`` provides a ``dynamic_shapes`` argument. The easiest way to work
with dynamic shapes is using ``Dim.AUTO`` and looking at the program that's returned. Dynamic
behavior is specified at the level of individual input dimensions; for each input we can specify
a tuple of values:

.. GENERATED FROM PYTHON SOURCE LINES 297-308

.. code-block:: Python

    from torch.export.dynamic_shapes import Dim

    dynamic_shapes = {
        "w": (Dim.AUTO, Dim.AUTO),
        "x": (Dim.AUTO,),
        "y": (Dim.AUTO, Dim.AUTO),
        "z": (Dim.AUTO,),
    }
    ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

.. GENERATED FROM PYTHON SOURCE LINES 309-323

Before we look at the program that's produced, let's understand what specifying ``dynamic_shapes``
entails, and how that interacts with export. For every input dimension where a ``Dim`` object is
specified, a symbol is `allocated `_, taking on a range of ``[2, inf]`` (why not ``[0, inf]`` or
``[1, inf]``? we'll explain later in the 0/1 specialization section).

Export then runs model tracing, looking at each operation that's performed by the model. Each
individual operation can emit what are called "guards": boolean conditions that are required to be
true for the program to be valid. When guards involve the symbols allocated for input dimensions,
the program contains restrictions on what input shapes are valid; i.e. the program's dynamic
behavior. The symbolic shapes subsystem is the part responsible for taking in all the emitted
guards and producing a final program representation that adheres to all of these guards.

Before we see this "final representation" in an ``ExportedProgram``, let's look at the guards
emitted by the toy model we're tracing. Here, each forward input tensor is annotated with the
symbol allocated at the start of tracing:

.. GENERATED FROM PYTHON SOURCE LINES 323-342
.. code-block:: Python

    class DynamicModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(5, 3)

        def forward(
            self,
            w: torch.Tensor,  # [s0, s1]
            x: torch.Tensor,  # [s2]
            y: torch.Tensor,  # [s3, s4]
            z: torch.Tensor,  # [s5]
        ):
            x0 = x + y  # guard: s2 == s4
            x1 = self.l(w)  # guard: s1 == 5
            x2 = x0.flatten()  # no guard added here
            x3 = x2 + z  # guard: s3 * s4 == s5
            return x1, x3

.. GENERATED FROM PYTHON SOURCE LINES 343-360

Let's understand each of the operations and the emitted guards:

- ``x0 = x + y``: This is an element-wise add with broadcasting, since ``x`` is a 1-d tensor and
  ``y`` a 2-d tensor. ``x`` is broadcast along the last dimension of ``y``, emitting the guard
  ``s2 == s4``.
- ``x1 = self.l(w)``: Calling ``nn.Linear()`` performs a matrix multiplication with model
  parameters. In export, parameters, buffers, and constants are considered program state, which is
  considered static, so this is a matmul between a dynamic input (``w: [s0, s1]``) and a
  statically-shaped tensor. This emits the guard ``s1 == 5``.
- ``x2 = x0.flatten()``: This call actually doesn't emit any guards! (at least none relevant to input shapes)
- ``x3 = x2 + z``: ``x2`` has shape ``[s3*s4]`` after flattening, and this element-wise add emits ``s3 * s4 == s5``.

Writing all of these guards down and summarizing is almost like a mathematical proof, which is what
the symbolic shapes subsystem tries to do! In summary, we can conclude that the program must have
the following input shapes to be valid:

- ``w: [s0, 5]``
- ``x: [s2]``
- ``y: [s3, s2]``
- ``z: [s2*s3]``

And when we do finally print out the exported program to see our result, those shapes are what we
see annotated on the corresponding inputs:

.. GENERATED FROM PYTHON SOURCE LINES 360-363

.. code-block:: Python

    print(ep)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, p_l_weight: "f32[3, 5]", p_l_bias: "f32[3]", w: "f32[s0, 5]", x: "f32[s2]", y: "f32[s3, s2]", z: "f32[s2*s3]"):
                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward, code: x0 = x + y  # [8, 4]
                add: "f32[s3, s2]" = torch.ops.aten.add.Tensor(x, y); x = y = None

                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward, code: x1 = self.l(w)  # [6, 3]
                linear: "f32[s0, 3]" = torch.ops.aten.linear.default(w, p_l_weight, p_l_bias); w = p_l_weight = p_l_bias = None

                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:270 in forward, code: x2 = x0.flatten()  # [32]
                flatten: "f32[s2*s3]" = torch.ops.aten.flatten.using_ints(add); add = None

                 # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward, code: x3 = x2 + z  # [32]
                add_1: "f32[s2*s3]" = torch.ops.aten.add.Tensor(flatten, z); flatten = z = None
                return (linear, add_1)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_weight'), target='l.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_l_bias'), target='l.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='w'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='y'), target=None, persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='z'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='linear'), target=None), OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_1'), target=None)])
    Range constraints: {s0: VR[2, int_oo], s2: VR[2, int_oo], s3: VR[2, int_oo], s2*s3: VR[4, int_oo]}
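The program now accepts any input shapes consistent with ``w: [s0, 5]``, ``x: [s2]``,
``y: [s3, s2]``, and ``z: [s2*s3]``. As a quick sketch (the shapes below are chosen only
to satisfy those constraints), the kind of call that failed earlier on the static program
now runs:

.. code-block:: Python

    # s0 = 10, s2 = 3, s3 = 7, so z must have 3 * 7 = 21 elements.
    out1, out3 = ep.module()(
        torch.randn(10, 5),
        torch.randn(3),
        torch.randn(7, 3),
        torch.randn(21),
    )
    print(out1.shape, out3.shape)  # expected: torch.Size([10, 3]) torch.Size([21])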
.. GENERATED FROM PYTHON SOURCE LINES 364-375

Another feature to notice is the ``range_constraints`` field above, which contains a valid range
for each symbol. This isn't so interesting currently, since this export call doesn't emit any
guards related to symbol bounds and each base symbol has a generic bound, but it will come up later.

So far, because we've been exporting this toy model, the experience has not been representative of
how hard it typically is to debug dynamic-shapes guards and issues. In most cases it isn't obvious
what guards are being emitted, or which operations and parts of user code are responsible. For this
toy model we can pinpoint the exact lines, and the guards are rather intuitive.

In more complicated cases, a helpful first step is always to enable verbose logging. This can be
done either with the environment variable ``TORCH_LOGS="+dynamic"``, or interactively with
``torch._logging.set_logs(dynamic=10)``:

.. GENERATED FROM PYTHON SOURCE LINES 375-379

.. code-block:: Python

    torch._logging.set_logs(dynamic=10)
    ep = export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)

.. rst-class:: sphx-glr-script-out

..
code-block:: none I0718 23:09:26.915000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [8/0] create_env I0718 23:09:26.917000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:26.918000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:26.919000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [8/0] runtime_assert True == True [statically known] I0718 23:09:26.921000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:26.923000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:26.924000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:26.926000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [8/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:26.929000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Eq(s2, 1)) == False [statically known] V0718 23:09:26.929000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [8/0] runtime_assert True == True [statically known] V0718 23:09:26.930000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Eq(s4, 1)) == False [statically known] I0718 23:09:26.931000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [8/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" I0718 23:09:26.932000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [8/0] set_replacement s4 = s2 (solve) VR[2, int_oo] V0718 23:09:26.933000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Ne(s2, 1)) == True [statically known] V0718 23:09:26.934000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Ne(s3, 1)) == True [statically known] I0718 23:09:26.941000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [8/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more 
info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" V0718 23:09:26.941000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [8/0] _update_var_to_range s1 = VR[5, 5] (update) I0718 23:09:26.942000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [8/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5] V0718 23:09:26.944000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Eq(s0, 1)) == False [statically known] V0718 23:09:26.949000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known] V0718 23:09:26.950000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Eq(s5, 1)) == False [statically known] I0718 23:09:26.951000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [8/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" V0718 23:09:26.952000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [8/0] _update_var_to_range s5 = VR[4, int_oo] (update) I0718 23:09:26.953000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [8/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo] V0718 23:09:26.955000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [8/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known] I0718 23:09:26.961000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [8/0] produce_guards V0718 23:09:26.961000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['w'].size()[0] s0 None V0718 23:09:26.961000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['w'].size()[1] 5 None V0718 23:09:26.962000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['w'].stride()[0] 5 None V0718 23:09:26.962000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['w'].stride()[1] 1 None V0718 23:09:26.962000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['w'].storage_offset() 0 None V0718 23:09:26.963000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['x'].size()[0] s2 None V0718 23:09:26.963000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['x'].stride()[0] 1 None V0718 23:09:26.963000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:26.963000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['y'].size()[0] s3 None V0718 23:09:26.964000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['y'].size()[1] s2 None V0718 23:09:26.964000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['y'].stride()[0] s2 None V0718 23:09:26.964000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['y'].stride()[1] 1 None V0718 23:09:26.965000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['y'].storage_offset() 0 None V0718 23:09:26.965000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['z'].size()[0] s2*s3 None V0718 23:09:26.965000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['z'].stride()[0] 1 None V0718 23:09:26.966000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [8/0] track_symint L['z'].storage_offset() 0 None V0718 
23:09:26.997000 25987 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Ne(s0, 1)) == True [statically known] .. GENERATED FROM PYTHON SOURCE LINES 380-383 This spits out quite a handful, even with this simple toy model. The log lines here have been cut short at front and end to ignore unnecessary info, but looking through the logs we can see the lines relevant to what we described above; e.g. the allocation of symbols: .. GENERATED FROM PYTHON SOURCE LINES 383-394 .. code-block:: Python """ create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) runtime_assert True == True [statically known] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in ) """ .. rst-class:: sphx-glr-script-out .. code-block:: none "\ncreate_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\nruntime_assert True == True [statically known]\ncreate_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\ncreate_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2841 in )\n" .. GENERATED FROM PYTHON SOURCE LINES 395-397 The lines with `create_symbol` show when a new symbol has been allocated, and the logs also identify the tensor variable names and dimensions they've been allocated for. In other lines we can also see the guards emitted: .. GENERATED FROM PYTHON SOURCE LINES 397-404 .. code-block:: Python """ runtime_assert Eq(s2, s4) [guard added] x0 = x + y # output shape: [8, 4] # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" """ .. rst-class:: sphx-glr-script-out .. 
code-block:: none

    '\nruntime_assert Eq(s2, s4) [guard added] x0 = x + y  # output shape: [8, 4]  # dynamic_shapes_tutorial.py:16 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)"\nruntime_assert Eq(s1, 5) [guard added] x1 = self.l(w)  # [6, 3]  # dynamic_shapes_tutorial.py:17 in forward (_meta_registrations.py:2127 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)"\nruntime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z  # [32]  # dynamic_shapes_tutorial.py:19 in forward (_subclasses/fake_impls.py:845 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)"\n'

.. GENERATED FROM PYTHON SOURCE LINES 405-415

Next to the ``[guard added]`` messages, we also see the responsible user lines of code; luckily,
the model here is simple enough. In many real-world cases it's not so straightforward: high-level
torch operations can have complicated fake-kernel implementations or operator decompositions that
complicate where and what guards are emitted. In such cases the best way to dig deeper and
investigate is to follow the logs' suggestion, and re-run with the environment variable
``TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="..."``, to further attribute the guard of interest.

``Dim.AUTO`` is just one of the available options for interacting with ``dynamic_shapes``; as of
this writing, two other options are available: ``Dim.DYNAMIC`` and ``Dim.STATIC``. ``Dim.STATIC``
simply marks a dimension static, while ``Dim.DYNAMIC`` is similar to ``Dim.AUTO`` in all ways
except one: it raises an error when specializing to a constant; this is designed to maintain
dynamism. See for example what happens when a static guard is emitted on a dynamically-marked
dimension:

.. GENERATED FROM PYTHON SOURCE LINES 415-422

.. code-block:: Python

    dynamic_shapes["w"] = (Dim.AUTO, Dim.DYNAMIC)
    try:
        export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes)
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

..
code-block:: none I0718 23:09:27.016000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [9/0] create_env I0718 23:09:27.019000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.019000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.020000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [9/0] runtime_assert True == True [statically known] I0718 23:09:27.022000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.024000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.025000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.027000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [9/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.029000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Eq(s2, 1)) == False [statically known] V0718 23:09:27.029000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [9/0] runtime_assert True == True [statically known] V0718 23:09:27.030000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Eq(s4, 1)) == False [statically known] I0718 23:09:27.031000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [9/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" I0718 23:09:27.032000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [9/0] set_replacement s4 = s2 (solve) VR[2, int_oo] V0718 23:09:27.033000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Ne(s2, 1)) == True [statically known] V0718 23:09:27.034000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Ne(s3, 1)) == True [statically known] I0718 23:09:27.040000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [9/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more 
info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" V0718 23:09:27.041000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [9/0] _update_var_to_range s1 = VR[5, 5] (update) I0718 23:09:27.041000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [9/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5] V0718 23:09:27.043000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Eq(s0, 1)) == False [statically known] V0718 23:09:27.048000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known] V0718 23:09:27.049000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Eq(s5, 1)) == False [statically known] I0718 23:09:27.049000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [9/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" V0718 23:09:27.050000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [9/0] _update_var_to_range s5 = VR[4, int_oo] (update) I0718 23:09:27.051000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [9/0] set_replacement s5 = s2*s3 (solve) VR[4, int_oo] V0718 23:09:27.053000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [9/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known] I0718 23:09:27.059000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [9/0] produce_guards V0718 23:09:27.059000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['w'].size()[0] s0 None V0718 23:09:27.059000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['w'].size()[1] 5 RelaxedUnspecConstraint(warn_only=False) V0718 23:09:27.060000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['w'].stride()[0] 5 None V0718 23:09:27.060000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['w'].stride()[1] 1 None V0718 23:09:27.060000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['w'].storage_offset() 0 None V0718 23:09:27.060000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['x'].size()[0] s2 None V0718 23:09:27.061000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['x'].stride()[0] 1 None V0718 23:09:27.061000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.061000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['y'].size()[0] s3 None V0718 23:09:27.062000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['y'].size()[1] s2 None V0718 23:09:27.062000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['y'].stride()[0] s2 None V0718 23:09:27.062000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['y'].stride()[1] 1 None V0718 23:09:27.062000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['y'].storage_offset() 0 None V0718 23:09:27.063000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['z'].size()[0] s2*s3 None V0718 23:09:27.063000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint L['z'].stride()[0] 1 None V0718 23:09:27.063000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [9/0] track_symint 
L['z'].storage_offset() 0 None E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Error while creating guard: E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Name: '' E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Source: shape_env E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Create Function: SHAPE_ENV E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Guard Types: None E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Code List: None E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Object Weakref: None E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Guarded Class Weakref: None E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] Traceback (most recent call last): E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] return self.create_fn(builder, self) E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] python_code_parts, verbose_code_parts = _get_code_parts( E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] return output_graph.shape_env.produce_guards_verbose( E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] raise ConstraintViolationError( E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". E0718 23:09:27.065000 25987 torch/_guards.py:359] [9/0] - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5). 
E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] Created at: E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 694, in transform E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] tracer = InstructionTranslator( E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__ E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] output=OutputGraph( E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 358, in __init__ E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] self.init_ambient_guards() E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards E0718 23:09:27.068000 25987 torch/_guards.py:361] [9/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1722, in inner raise constraint_violation_error File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner result_traced = opt_f(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__ return self._torchdynamo_orig_callable( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__ return _compile( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function return function(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner return _compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner check_fn = CheckFunctionManager( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2481, in __init__ guard.create(builder) File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create return self.create_fn(builder, self) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV python_code_parts, verbose_code_parts = _get_code_parts( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 
1942, in _get_code_parts return output_graph.shape_env.produce_guards_verbose( File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5). During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 418, in export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export return _export( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export ep = _export_for_training( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir gm_torch_level = _export_to_torch_ir( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 torch._dynamo.exc.UserError: Constraints violated (L['w'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['w'].size()[1]) are valid because L['w'].size()[1] was inferred to be a constant (5). .. GENERATED FROM PYTHON SOURCE LINES 423-426 Static guards also aren't always inherent to the model; they can also come from user specifications. In fact, a common pitfall leading to shape specializations is when the user specifies conflicting markers for equivalent dimensions; one dynamic and another static. The same error type is raised when this is the case for ``x.shape[0]`` and ``y.shape[1]``: .. GENERATED FROM PYTHON SOURCE LINES 426-435 .. code-block:: Python dynamic_shapes["w"] = (Dim.AUTO, Dim.AUTO) dynamic_shapes["x"] = (Dim.STATIC,) dynamic_shapes["y"] = (Dim.AUTO, Dim.DYNAMIC) try: export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. 
code-block:: none I0718 23:09:27.086000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [10/0] create_env I0718 23:09:27.088000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [10/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.089000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [10/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.089000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [10/0] runtime_assert True == True [statically known] I0718 23:09:27.093000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [10/0] create_symbol s2 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.093000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [10/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.096000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [10/0] create_symbol s4 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.099000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [10/0] eval size_oblivious(Eq(s3, 1)) == False [statically known] I0718 23:09:27.103000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [10/0] runtime_assert Eq(s3, 4) [guard added] x0 = x + y # [8, 4] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:268 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s3, 4)" V0718 23:09:27.104000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [10/0] _update_var_to_range s3 = VR[4, 4] (update) I0718 23:09:27.105000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [10/0] set_replacement s3 = 4 (range_refined_to_singleton) VR[4, 4] V0718 23:09:27.107000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [10/0] eval size_oblivious(Ne(s2, 1)) == True [statically known] I0718 23:09:27.113000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [10/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # [6, 3] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:269 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" V0718 23:09:27.114000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [10/0] _update_var_to_range s1 = VR[5, 5] (update) I0718 23:09:27.114000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [10/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5] V0718 23:09:27.116000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [10/0] eval size_oblivious(Eq(s0, 1)) == False [statically known] V0718 23:09:27.117000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [10/0] runtime_assert 
True == True [statically known] V0718 23:09:27.124000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [10/0] eval size_oblivious(Eq(s4, 1)) == False [statically known] I0718 23:09:27.128000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [10/0] runtime_assert Eq(4*s2, s4) [guard added] x3 = x2 + z # [32] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:271 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(4*s2, s4)" V0718 23:09:27.131000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [10/0] _update_var_to_range s4 = VR[8, int_oo] (update) I0718 23:09:27.132000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [10/0] set_replacement s4 = 4*s2 (solve) VR[8, int_oo] I0718 23:09:27.138000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [10/0] produce_guards V0718 23:09:27.139000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['w'].size()[0] s0 None V0718 23:09:27.139000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['w'].size()[1] 5 None V0718 23:09:27.139000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['w'].stride()[0] 5 None V0718 23:09:27.140000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['w'].stride()[1] 1 None V0718 23:09:27.140000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['w'].storage_offset() 0 None V0718 23:09:27.140000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['x'].size()[0] 4 None V0718 23:09:27.141000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['x'].stride()[0] 1 None V0718 23:09:27.141000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.141000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['y'].size()[0] s2 None V0718 23:09:27.141000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['y'].size()[1] 4 RelaxedUnspecConstraint(warn_only=False) V0718 23:09:27.142000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['y'].stride()[0] 4 None V0718 23:09:27.142000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['y'].stride()[1] 1 None V0718 23:09:27.142000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['y'].storage_offset() 0 None V0718 23:09:27.143000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['z'].size()[0] 4*s2 None V0718 23:09:27.143000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['z'].stride()[0] 1 None V0718 23:09:27.143000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [10/0] track_symint L['z'].storage_offset() 0 None E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Error while creating guard: E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Name: '' E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Source: shape_env E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Create Function: SHAPE_ENV E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Guard Types: None E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Code List: None E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Object Weakref: None E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Guarded Class Weakref: None E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] Traceback 
(most recent call last): E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] return self.create_fn(builder, self) E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] python_code_parts, verbose_code_parts = _get_code_parts( E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] return output_graph.shape_env.produce_guards_verbose( E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] raise ConstraintViolationError( E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". E0718 23:09:27.145000 25987 torch/_guards.py:359] [10/0] - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4). E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] Created at: E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 694, in transform E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] tracer = InstructionTranslator( E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__ E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] output=OutputGraph( E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 358, in __init__ E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] self.init_ambient_guards() E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards E0718 23:09:27.146000 25987 torch/_guards.py:361] [10/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1722, in inner raise constraint_violation_error File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner result_traced = opt_f(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, 
**kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__ return self._torchdynamo_orig_callable( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__ return _compile( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function return function(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner return _compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner check_fn = CheckFunctionManager( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2481, in __init__ guard.create(builder) File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create return self.create_fn(builder, self) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV python_code_parts, verbose_code_parts = _get_code_parts( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts return output_graph.shape_env.produce_guards_verbose( File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4). 
During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 431, in export(model, (w, x, y, z), dynamic_shapes=dynamic_shapes) File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export return _export( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export ep = _export_for_training( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir gm_torch_level = _export_to_torch_ir( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 torch._dynamo.exc.UserError: Constraints violated (L['y'].size()[1])! For more information, run with TORCH_LOGS="+dynamic". - Not all values of RelaxedUnspecConstraint(L['y'].size()[1]) are valid because L['y'].size()[1] was inferred to be a constant (4).

.. GENERATED FROM PYTHON SOURCE LINES 436-443

Here you might ask why export "specializes", i.e. why we resolve this static/dynamic conflict by going with the static route. The answer lies in the symbolic shapes system of symbols and guards described above. When ``x.shape[0]`` is marked static, we don't allocate a symbol, and we compile treating this shape as the concrete integer 4. A symbol is allocated for ``y.shape[1]``, and so we finally emit the guard ``s3 == 4``, leading to specialization. One feature of export is that during tracing, statements like asserts, ``torch._check()``, and ``if/else`` conditions will also emit guards. See what happens when we augment the existing model with such statements:

.. GENERATED FROM PYTHON SOURCE LINES 443-472

.. code-block:: Python

    class DynamicModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.l = torch.nn.Linear(5, 3)

        def forward(self, w, x, y, z):
            assert w.shape[0] <= 512
            torch._check(x.shape[0] >= 4)
            if w.shape[0] == x.shape[0] + 2:
                x0 = x + y
                x1 = self.l(w)
                x2 = x0.flatten()
                x3 = x2 + z
                return x1, x3
            else:
                return w

    dynamic_shapes = {
        "w": (Dim.AUTO, Dim.AUTO),
        "x": (Dim.AUTO,),
        "y": (Dim.AUTO, Dim.AUTO),
        "z": (Dim.AUTO,),
    }
    try:
        ep = export(DynamicModel(), (w, x, y, z), dynamic_shapes=dynamic_shapes)
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out .. 
code-block:: none I0718 23:09:27.162000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [11/0] create_env I0718 23:09:27.164000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s0 = 6 for L['w'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.164000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s1 = 5 for L['w'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.165000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [11/0] runtime_assert True == True [statically known] I0718 23:09:27.168000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s2 = 4 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.170000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s3 = 8 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.170000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s4 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s4" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.173000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [11/0] create_symbol s5 = 32 for L['z'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s5" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.180000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] runtime_assert s0 <= 512 [guard added] assert w.shape[0] <= 512 # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:450 in forward (_dynamo/symbolic_convert.py:669 in inner), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s0 <= 512" V0718 23:09:27.181000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s0 = VR[2, 512] (update) I0718 23:09:27.186000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] runtime_assert s2 >= 4 [guard added] torch._check(x.shape[0] >= 4) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:451 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="s2 >= 4" V0718 23:09:27.186000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s2 = VR[4, int_oo] (update) I0718 23:09:27.193000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] eval Eq(s0, s2 + 2) [guard added] if w.shape[0] == x.shape[0] + 2: # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:452 in forward (_dynamo/variables/tensor.py:1245 in evaluate_expr), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2 + 2)" V0718 23:09:27.196000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s2 = VR[4, 510] 
(update) V0718 23:09:27.197000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s0 = VR[6, 512] (update) I0718 23:09:27.197000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [11/0] set_replacement s0 = s2 + 2 (solve) VR[6, 512] V0718 23:09:27.199000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Eq(s2, 1)) == False [statically known] V0718 23:09:27.200000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [11/0] runtime_assert True == True [statically known] V0718 23:09:27.201000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Eq(s4, 1)) == False [statically known] I0718 23:09:27.203000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] runtime_assert Eq(s2, s4) [guard added] x0 = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:453 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2, s4)" V0718 23:09:27.204000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s4 = VR[4, 510] (update) I0718 23:09:27.205000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [11/0] set_replacement s4 = s2 (solve) VR[4, 510] V0718 23:09:27.207000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Ne(s2, 1)) == True [statically known] V0718 23:09:27.208000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Ne(s3, 1)) == True [statically known] I0718 23:09:27.214000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] runtime_assert Eq(s1, 5) [guard added] x1 = self.l(w) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:454 in forward (_meta_registrations.py:2236 in meta_mm), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 5)" V0718 23:09:27.215000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s1 = VR[5, 5] (update) I0718 23:09:27.215000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [11/0] set_replacement s1 = 5 (range_refined_to_singleton) VR[5, 5] V0718 23:09:27.225000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Eq(s2*s3, 1)) == False [statically known] V0718 23:09:27.226000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Eq(s5, 1)) == False [statically known] I0718 23:09:27.235000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [11/0] runtime_assert Eq(s2*s3, s5) [guard added] x3 = x2 + z # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:456 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s2*s3, s5)" V0718 23:09:27.236000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [11/0] _update_var_to_range s5 = VR[8, int_oo] (update) I0718 23:09:27.237000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [11/0] set_replacement s5 = s2*s3 (solve) VR[8, int_oo] V0718 23:09:27.239000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [11/0] eval size_oblivious(Ne(s2*s3, 1)) == True [statically known] V0718 23:09:27.242000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [11/0] runtime_assert s2 >= 4 == True [statically known] I0718 23:09:27.248000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [11/0] produce_guards V0718 23:09:27.249000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['w'].size()[0] s2 + 2 None V0718 23:09:27.249000 
25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['w'].size()[1] 5 None V0718 23:09:27.250000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['w'].stride()[0] 5 None V0718 23:09:27.250000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['w'].stride()[1] 1 None V0718 23:09:27.250000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['w'].storage_offset() 0 None V0718 23:09:27.250000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['x'].size()[0] s2 None V0718 23:09:27.251000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['x'].stride()[0] 1 None V0718 23:09:27.251000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.251000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['y'].size()[0] s3 None V0718 23:09:27.252000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['y'].size()[1] s2 None V0718 23:09:27.252000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['y'].stride()[0] s2 None V0718 23:09:27.252000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['y'].stride()[1] 1 None V0718 23:09:27.253000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['y'].storage_offset() 0 None V0718 23:09:27.253000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['z'].size()[0] s2*s3 None V0718 23:09:27.253000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['z'].stride()[0] 1 None V0718 23:09:27.254000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [11/0] track_symint L['z'].storage_offset() 0 None

.. GENERATED FROM PYTHON SOURCE LINES 473-481

Each of these statements emits an additional guard, and the exported program reflects the changes: ``s0`` is eliminated in favor of ``s2 + 2``, and ``s2`` now has lower and upper bounds, reflected in ``range_constraints``. For the if/else condition, you might ask why the True branch was taken, and why it wasn't the ``w.shape[0] != x.shape[0] + 2`` guard that got emitted from tracing. The answer is that export is guided by the sample inputs provided for tracing, and specializes on the branches taken. If sample input shapes that fail the ``if`` condition were provided instead, export would trace and emit guards corresponding to the ``else`` branch. Additionally, you might ask why we traced only the ``if`` branch, and whether it's possible to maintain control-flow in your program and keep both branches alive. For that, refer to rewriting your model code following the ``Control Flow Ops`` section above.

.. GENERATED FROM PYTHON SOURCE LINES 483-491

0/1 specialization
^^^^^^^^^^^^^^^^^^

Since we're talking about guards and specializations, it's a good time to talk about the 0/1 specialization issue we brought up earlier. The bottom line is that export will specialize on sample input dimensions with value 0 or 1, because these shapes have trace-time properties that don't generalize to other shapes. For example, size 1 tensors can broadcast while other sizes fail; and size 0 ... . This just means that you should specify 0/1 sample inputs when you'd like your program to hardcode them, and non-0/1 sample inputs when dynamic behavior is desirable. See what happens at runtime when we export this linear layer:

.. GENERATED FROM PYTHON SOURCE LINES 491-504

.. code-block:: Python

    ep = export(
        torch.nn.Linear(4, 3),
        (torch.randn(1, 4),),
        dynamic_shapes={
            "input": (Dim.AUTO, Dim.STATIC),
        },
    )
    try:
        ep.module()(torch.randn(2, 4))
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I0718 23:09:27.315000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [12/0] create_env I0718 23:09:27.328000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [12/0] produce_guards V0718 23:09:27.328000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['args'][0].size()[0] 1 None V0718 23:09:27.329000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['args'][0].size()[1] 4 None V0718 23:09:27.329000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['args'][0].stride()[0] 4 None V0718 23:09:27.329000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['args'][0].stride()[1] 1 None V0718 23:09:27.330000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [12/0] track_symint L['args'][0].storage_offset() 0 None Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 500, in ep.module()(torch.randn(2, 4)) File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 830, in call_wrapped return self._wrapped_call(self, *args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 406, in __call__ raise e File "/usr/local/lib/python3.10/dist-packages/torch/fx/graph_module.py", line 393, in __call__ return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1857, in _call_impl return inner() File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1784, in inner args_kwargs_result = hook(self, args, kwargs) # type: ignore[misc] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 838, in _fn return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_unlift.py", line 55, in _check_input_constraints_pre_hook _check_input_constraints_for_graph( File "/usr/local/lib/python3.10/dist-packages/torch/_export/utils.py", line 398, in _check_input_constraints_for_graph raise RuntimeError( RuntimeError: Expected input at *args[0].shape[0] to be equal to 1, but got 2

.. GENERATED FROM PYTHON SOURCE LINES 505-517

Named Dims
^^^^^^^^^^

So far we've only been talking about 3 ways to specify dynamic shapes: ``Dim.AUTO``, ``Dim.DYNAMIC``, and ``Dim.STATIC``. The attraction of these is the low-friction user experience; all the guards emitted during model tracing are adhered to, and dynamic behavior like min/max ranges, relations, and static/dynamic dimensions is automatically figured out under the hood by export. The dynamic shapes subsystem essentially acts as a "discovery" process, summarizing these guards and presenting what export believes is the overall dynamic behavior of the program.
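One way to see the result of this "discovery" is to inspect the ``range_constraints`` of the produced ``ExportedProgram``. Below is a minimal sketch reusing the linear layer from above, this time exported with a non-0/1 batch size so the batch dimension can remain dynamic (the exact printed contents may vary across PyTorch versions):

.. code-block:: Python

    ep = export(
        torch.nn.Linear(4, 3),
        (torch.randn(8, 4),),  # non-0/1 batch size, so dim 0 can stay dynamic
        dynamic_shapes={
            "input": (Dim.AUTO, Dim.STATIC),
        },
    )
    # Symbols allocated for dynamic dimensions, mapped to their inferred value ranges.
    print(ep.range_constraints)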
The drawback of this design appears once the user has stronger expectations or beliefs about the dynamic behavior of these models - maybe there is a strong desire for dynamism, and specializations on particular dimensions are to be avoided at all costs; or maybe we just want to catch changes in dynamic behavior caused by changes to the original model code, or to underlying decompositions or meta-kernels. Such changes won't be detected, and the ``export()`` call will most likely succeed, unless tests are in place that check the resulting ``ExportedProgram`` representation. For such cases, our stance is to recommend the "traditional" way of specifying dynamic shapes, which longer-term users of export might be familiar with: named ``Dims``:

.. GENERATED FROM PYTHON SOURCE LINES 517-525

.. code-block:: Python

    dx = Dim("dx", min=4, max=256)
    dh = Dim("dh", max=512)
    dynamic_shapes = {
        "x": (dx, None),
        "y": (2 * dx, dh),
    }

.. GENERATED FROM PYTHON SOURCE LINES 526-536

This style of dynamic shapes allows the user to specify which symbols are allocated for input dimensions and the min/max bounds on those symbols, and places restrictions on the dynamic behavior of the ``ExportedProgram`` produced; ``ConstraintViolation`` errors will be raised if model tracing emits guards that conflict with the relations or static/dynamic specifications given. For example, the above specification asserts the following:

- ``x.shape[0]`` is to have range ``[4, 256]``, and is related to ``y.shape[0]`` by ``y.shape[0] == 2 * x.shape[0]``.
- ``x.shape[1]`` is static.
- ``y.shape[1]`` has range ``[2, 512]``, and is unrelated to any other dimension.

In this design, we allow relations between dimensions to be specified with univariate linear expressions: ``A * dim + B`` can be specified for any dimension. This allows users to specify more complex constraints like integer divisibility for dynamic dimensions:

.. GENERATED FROM PYTHON SOURCE LINES 536-542

.. code-block:: Python

    dx = Dim("dx", min=4, max=512)
    dynamic_shapes = {
        "x": (4 * dx, None)  # x.shape[0] has range [16, 2048], and is divisible by 4.
    }

.. GENERATED FROM PYTHON SOURCE LINES 543-549

Constraint violations, suggested fixes
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

One common issue with this specification style (before ``Dim.AUTO`` was introduced) is that the specification would often be mismatched with what model tracing produced. That would lead to ``ConstraintViolation`` errors and export-suggested fixes; see, for example, this model & specification, where the model inherently requires equality between dimensions 0 of ``x`` and ``y``, and requires dimension 1 to be static.

.. GENERATED FROM PYTHON SOURCE LINES 549-568

.. code-block:: Python

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            w = x + y
            return w + torch.ones(4)

    dx, dy, d1 = torch.export.dims("dx", "dy", "d1")
    try:
        ep = export(
            Foo(),
            (torch.randn(6, 4), torch.randn(6, 4)),
            dynamic_shapes={
                "x": (dx, d1),
                "y": (dy, d1),
            },
        )
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out .. 
code-block:: none I0718 23:09:27.452000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [13/0] create_env I0718 23:09:27.455000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s0 = 6 for L['x'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.455000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s1 = 4 for L['x'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.456000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [13/0] runtime_assert True == True [statically known] I0718 23:09:27.459000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s2 = 6 for L['y'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" I0718 23:09:27.459000 25987 torch/fx/experimental/symbolic_shapes.py:4606] [13/0] create_symbol s3 = 4 for L['y'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:3033 in ), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s3" or to suppress this message run with TORCHDYNAMO_EXTENDED_ADVICE="0" V0718 23:09:27.463000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s1, 1)) == False [statically known] V0718 23:09:27.464000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [13/0] runtime_assert True == True [statically known] V0718 23:09:27.464000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s0, 1)) == False [statically known] V0718 23:09:27.465000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s3, 1)) == False [statically known] I0718 23:09:27.467000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s1, s3) [guard added] w = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, s3)" I0718 23:09:27.468000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s3 = s1 (solve) VR[2, int_oo] V0718 23:09:27.469000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Eq(s2, 1)) == False [statically known] I0718 23:09:27.470000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s0, s2) [guard added] w = x + y # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:552 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s0, s2)" I0718 23:09:27.472000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s2 = s0 (solve) VR[2, int_oo] V0718 23:09:27.473000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Ne(s1, 1)) == True [statically known] V0718 23:09:27.474000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [13/0] eval size_oblivious(Ne(s0, 1)) == True [statically known] I0718 23:09:27.481000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [13/0] runtime_assert Eq(s1, 4) [guard added] return w + torch.ones(4) # 
ar/lib/workspace/intermediate_source/torch_export_tutorial.py:553 in forward (_subclasses/fake_impls.py:881 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Eq(s1, 4)" V0718 23:09:27.482000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [13/0] _update_var_to_range s1 = VR[4, 4] (update) I0718 23:09:27.483000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s1 = 4 (range_refined_to_singleton) VR[4, 4] V0718 23:09:27.486000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [13/0] _update_var_to_range s3 = VR[4, 4] (update) I0718 23:09:27.487000 25987 torch/fx/experimental/symbolic_shapes.py:6234] [13/0] set_replacement s3 = 4 (find) VR[4, 4] I0718 23:09:27.490000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [13/0] produce_guards V0718 23:09:27.490000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo]) V0718 23:09:27.490000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo]) V0718 23:09:27.491000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].stride()[0] 4 None V0718 23:09:27.491000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].stride()[1] 1 None V0718 23:09:27.491000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.492000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].size()[0] s0 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo]) V0718 23:09:27.492000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].size()[1] 4 StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo]) V0718 23:09:27.492000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].stride()[0] 4 None V0718 23:09:27.493000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].stride()[1] 1 None V0718 23:09:27.493000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [13/0] track_symint L['y'].storage_offset() 0 None E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Error while creating guard: E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Name: '' E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Source: shape_env E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Create Function: SHAPE_ENV E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Guard Types: None E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Code List: None E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Object Weakref: None E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Guarded Class Weakref: None E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] Traceback (most recent call last): E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] return self.create_fn(builder, self) E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] python_code_parts, verbose_code_parts = _get_code_parts( E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] File 
"/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] return output_graph.shape_env.produce_guards_verbose( E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] raise ConstraintViolationError( E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic". E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). E0718 23:09:27.494000 25987 torch/_guards.py:359] [13/0] - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal. E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] Created at: E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 694, in transform E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] tracer = InstructionTranslator( E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3329, in __init__ E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] output=OutputGraph( E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 358, in __init__ E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] self.init_ambient_guards() E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/output_graph.py", line 512, in init_ambient_guards E0718 23:09:27.496000 25987 torch/_guards.py:361] [13/0] self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV)) Traceback (most recent call last): File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1722, in inner raise constraint_violation_error File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner result_traced = opt_f(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__ return self._torchdynamo_orig_callable( 
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__ return _compile( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile guarded_code = compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function return function(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner return _compile_inner(code, one_graph, hooks, transform) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 906, in _compile_inner check_fn = CheckFunctionManager( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 2481, in __init__ guard.create(builder) File "/usr/local/lib/python3.10/dist-packages/torch/_guards.py", line 357, in create return self.create_fn(builder, self) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1959, in SHAPE_ENV python_code_parts, verbose_code_parts = _get_code_parts( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/guards.py", line 1942, in _get_code_parts return output_graph.shape_env.produce_guards_verbose( File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 5409, in produce_guards_verbose raise ConstraintViolationError( torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic". - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal. 
Suggested fixes: d1 = 4 dy = dx During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 557, in ep = export( File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export return _export( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export ep = _export_for_training( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir gm_torch_level = _export_to_torch_ir( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 756, in _export_to_torch_ir raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 torch._dynamo.exc.UserError: Constraints violated (d1, dy)! For more information, run with TORCH_LOGS="+dynamic". - Not all values of d1 = L['x'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). - Not all values of d1 = L['y'].size()[1] in the specified range are valid because d1 was inferred to be a constant (4). - The values of dy = L['y'].size()[0] and dx = L['x'].size()[0] must always be equal. Suggested fixes: d1 = 4 dy = dx

.. GENERATED FROM PYTHON SOURCE LINES 569-580

The expectation with suggested fixes is that the user can interactively copy-paste the changes into their dynamic shapes specification, and successfully export afterwards. Lastly, there are a couple of nice-to-knows about the options for specification:

- ``None`` is a good option for static behavior:

  - ``dynamic_shapes=None`` (default) exports with the entire model being static.
  - specifying ``None`` at an input-level exports with all tensor dimensions static, and is also required for non-tensor inputs.
  - specifying ``None`` at a dimension-level specializes that dimension, though this is deprecated in favor of ``Dim.STATIC``.

- specifying per-dimension integer values also produces static behavior, and will additionally check that the provided sample input matches the specification.

These options are combined in the inputs & dynamic shapes spec below:

.. GENERATED FROM PYTHON SOURCE LINES 580-594

.. code-block:: Python

    inputs = (
        torch.randn(4, 4),
        torch.randn(3, 3),
        16,
        False,
    )
    dynamic_shapes = {
        "tensor_0": (Dim.AUTO, None),
        "tensor_1": None,
        "int_val": None,
        "bool_val": None,
    }

.. GENERATED FROM PYTHON SOURCE LINES 595-615

Data-dependent errors
---------------------

While trying to export models, you may have encountered errors like "Could not guard on data-dependent expression", or "Could not extract specialized integer from data-dependent expression".
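For instance, a module as small as the following sketch can trigger the first error, since tracing must pick a branch based on a value export cannot see (``DataDependent`` is a hypothetical name; we walk through this pattern, and how to resolve it, in the rest of this section):

.. code-block:: Python

    class DataDependent(torch.nn.Module):  # hypothetical name, for illustration
        def forward(self, x, y):
            n = x.item()  # data-dependent: no concrete value at trace time
            if n > 0:     # export cannot evaluate this guard on a FakeTensor
                return y + 1
            return y - 1

    # export(DataDependent(), (torch.tensor(3), torch.randn(4)))  # would raise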
These errors exist because ``torch.export()`` compiles programs using FakeTensors, which symbolically represent their real tensor counterparts. While these have equivalent symbolic properties (e.g. sizes, strides, dtypes), they diverge in that FakeTensors do not contain any data values. While this avoids unnecessary memory usage and expensive computation, it does mean that export may be unable to compile, out of the box, parts of user code where compilation relies on data values. In short, if the compiler requires a concrete, data-dependent value in order to proceed, it will error out, complaining that the value is not available.

Data-dependent values appear in many places, and common sources are calls like ``item()``, ``tolist()``, or ``torch.unbind()`` that extract scalar values from tensors. How are these values represented in the exported program? In the `Constraints/Dynamic Shapes `_ section, we talked about allocating symbols to represent dynamic input dimensions. The same happens here: we allocate symbols for every data-dependent value that appears in the program. The important distinction is that these are "unbacked" symbols, in contrast to the "backed" symbols allocated for input dimensions. The `"backed/unbacked" `_ nomenclature refers to the presence/absence of a "hint" for the symbol: a concrete value backing the symbol, which can inform the compiler how to proceed.

In the input shape symbol case (backed symbols), these hints are simply the sample input shapes provided, which explains why control-flow branching is determined by the sample input properties. For data-dependent values, the symbols are taken from FakeTensor "data" during tracing, and so the compiler doesn't know the actual values (hints) that these symbols would take on. Let's see how these show up in exported programs:

.. GENERATED FROM PYTHON SOURCE LINES 615-629

.. code-block:: Python

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            b = y.tolist()
            return b + [a]

    inps = (
        torch.tensor(1),
        torch.tensor([2, 3]),
    )
    ep = export(Foo(), inps)
    print(ep)

.. rst-class:: sphx-glr-script-out .. 
code-block:: none I0718 23:09:27.511000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [14/0] create_env I0718 23:09:27.515000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [14/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.516000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [14/0] compute_unbacked_bindings [u0] I0718 23:09:27.518000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [14/0] create_unbacked_symint u1 [-int_oo, int_oo] b = y.tolist() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.518000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [14/0] compute_unbacked_bindings [u1] I0718 23:09:27.520000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [14/0] create_unbacked_symint u2 [-int_oo, int_oo] b = y.tolist() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.521000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [14/0] compute_unbacked_bindings [u2] I0718 23:09:27.524000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [14/0] produce_guards V0718 23:09:27.525000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.525000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].size()[0] 2 None V0718 23:09:27.525000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].stride()[0] 1 None V0718 23:09:27.525000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [14/0] track_symint L['y'].storage_offset() 0 None I0718 23:09:27.531000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u3 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.532000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u4 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.537000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u5 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.538000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u5] I0718 23:09:27.538000 25987 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u5 = u0 (rename_unbacked_to) VR[-int_oo, int_oo] I0718 23:09:27.540000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u6 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.540000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u6] I0718 23:09:27.541000 25987 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u6 = u1 (rename_unbacked_to) VR[-int_oo, int_oo] I0718 23:09:27.543000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u7 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.543000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u7] I0718 23:09:27.543000 25987 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u7 = u2 (rename_unbacked_to) VR[-int_oo, int_oo] ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]", y: 
"i64[2]"): # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:618 in forward, code: a = x.item() item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:619 in forward, code: b = y.tolist() select: "i64[]" = torch.ops.aten.select.int(y, 0, 0) item_1: "Sym(u1)" = torch.ops.aten.item.default(select); select = None select_1: "i64[]" = torch.ops.aten.select.int(y, 0, 1); y = None item_2: "Sym(u2)" = torch.ops.aten.item.default(select_1); select_1 = None return (item_1, item_2, item) Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=SymIntArgument(name='item_1'), target=None), OutputSpec(kind=, arg=SymIntArgument(name='item_2'), target=None), OutputSpec(kind=, arg=SymIntArgument(name='item'), target=None)]) Range constraints: {u0: VR[-int_oo, int_oo], u1: VR[-int_oo, int_oo], u2: VR[-int_oo, int_oo], u3: VR[-int_oo, int_oo], u4: VR[-int_oo, int_oo], u5: VR[-int_oo, int_oo], u6: VR[-int_oo, int_oo], u7: VR[-int_oo, int_oo]} .. GENERATED FROM PYTHON SOURCE LINES 630-634 The result is that 3 unbacked symbols (notice they're prefixed with "u", instead of the usual "s" for input shape/backed symbols) are allocated and returned: 1 for the ``item()`` call, and 1 for each of the elements of ``y`` with the ``tolist()`` call. Note from the range constraints field that these take on ranges of ``[-int_oo, int_oo]``, not the default ``[0, int_oo]`` range allocated to input shape symbols, since we have no information on what these values are - they don't represent sizes, so don't necessarily have positive values. .. GENERATED FROM PYTHON SOURCE LINES 636-641 Guards, torch._check() ^^^^^^^^^^^^^^^^^^^^^^ But the case above is easy to export, because the concrete values of these symbols aren't used in any compiler decision-making; all that's relevant is that the return values are unbacked symbols. The data-dependent errors highlighted in this section are cases like the following, where `data-dependent guards `_ are encountered: .. GENERATED FROM PYTHON SOURCE LINES 641-650 .. code-block:: Python class Foo(torch.nn.Module): def forward(self, x, y): a = x.item() if a // 2 >= 5: return y + 2 else: return y * 5 .. GENERATED FROM PYTHON SOURCE LINES 651-669 Here we actually need the "hint", or the concrete value of ``a`` for the compiler to decide whether to trace ``return y + 2`` or ``return y * 5`` as the output. Because we trace with FakeTensors, we don't know what ``a // 2 >= 5`` actually evaluates to, and export errors out with "Could not guard on data-dependent expression ``u0 // 2 >= 5 (unhinted)``". So how do we export this toy model? Unlike ``torch.compile()``, export requires full graph compilation, and we can't just graph break on this. Here are some basic options: 1. Manual specialization: we could intervene by selecting the branch to trace, either by removing the control-flow code to contain only the specialized branch, or using ``torch.compiler.is_compiling()`` to guard what's traced at compile-time. 2. ``torch.cond()``: we could rewrite the control-flow code to use ``torch.cond()`` so we don't specialize on a branch. While these options are valid, they have their pitfalls. 
Option 1 sometimes requires drastic, invasive rewrites of the model code to specialize, and ``torch.cond()`` is not a comprehensive system for handling data-dependent errors. As we will see, there are data-dependent errors that do not involve control flow.

The generally recommended approach is to start with ``torch._check()`` calls. While these give the impression of being pure assert statements, they are in fact a mechanism for informing the compiler of properties of symbols. A ``torch._check()`` call does act as an assertion at runtime, but when traced at compile-time, the checked expression is sent to the symbolic shapes subsystem for reasoning, and any symbol properties that follow from the expression being true are stored (provided the subsystem is smart enough to infer those properties). So even if unbacked symbols don't have hints, if we're able to communicate properties that are generally true for these symbols via ``torch._check()`` calls, we can potentially bypass data-dependent guards without rewriting the offending model code. For example, in the model above, inserting ``torch._check(a >= 10)`` would tell the compiler that ``y + 2`` can always be returned, while ``torch._check(a == 4)`` would tell it to return ``y * 5``. See what happens when we re-export this model.

.. GENERATED FROM PYTHON SOURCE LINES 669-687

.. code-block:: Python

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            # inform the compiler of a's range before it hits the branch
            torch._check(a >= 10)
            torch._check(a <= 60)
            if a // 2 >= 5:
                return y + 2
            else:
                return y * 5

    inps = (
        torch.tensor(32),
        torch.randn(4),
    )
    ep = export(Foo(), inps)
    print(ep)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I0718 23:09:27.553000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [15/0] create_env
    I0718 23:09:27.557000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [15/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense)
    I0718 23:09:27.557000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [15/0] compute_unbacked_bindings [u0]
    I0718 23:09:27.560000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert u0 >= 10 [guard added] torch._check(a >= 10) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:673 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 10"
    V0718 23:09:27.561000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range u0 = VR[10, int_oo] (update)
    I0718 23:09:27.566000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [15/0] runtime_assert u0 <= 60 [guard added] torch._check(a <= 60) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:674 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 <= 60"
    V0718 23:09:27.567000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [15/0] _update_var_to_range u0 = VR[10, 60] (update)
    V0718 23:09:27.572000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [15/0] eval False == True [statically known]
    V0718 23:09:27.575000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [15/0] runtime_assert u0 >= 10 == True [statically known]
    V0718 23:09:27.576000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [15/0] runtime_assert u0 <= 60 == True [statically known]
    I0718 23:09:27.579000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [15/0] produce_guards
    V0718 23:09:27.579000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['x'].storage_offset() 0 None
    V0718 23:09:27.579000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].size()[0] 4 None
    V0718 23:09:27.580000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].stride()[0] 1 None
    V0718 23:09:27.580000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [15/0] track_symint L['y'].storage_offset() 0 None
    I0718 23:09:27.592000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
    I0718 23:09:27.593000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u1]
    V0718 23:09:27.593000 25987 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u1 = VR[10, 60] (update)
    I0718 23:09:27.593000 25987 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u1 = u0 (rename_unbacked_to) VR[10, 60]
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]", y: "f32[4]"):
                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:672 in forward, code: a = x.item()
                item: "Sym(u0)" = torch.ops.aten.item.default(x); x = None
                ge_1: "Sym(u0 >= 10)" = item >= 10
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 10 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
                le_1: "Sym(u0 <= 60)" = item <= 60; item = None
                _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 60 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:676 in forward, code: return y + 2
                add: "f32[4]" = torch.ops.aten.add.Tensor(y, 2); y = None
                return (add,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {u0: VR[10, 60], u1: VR[10, 60]}

.. GENERATED FROM PYTHON SOURCE LINES 688-698

Export succeeds, and note from the range constraints field that ``u0`` now takes on a range of ``[10, 60]``.

So what information do ``torch._check()`` calls actually communicate? This varies as the symbolic shapes subsystem gets smarter, but at a fundamental level, the following generally hold:

1. Equality with non-data-dependent expressions: ``torch._check()`` calls that communicate equalities like ``u0 == s0 + 4`` or ``u0 == 5``.

2. Range refinement: calls that provide lower or upper bounds for symbols, like the above.

3. Some basic reasoning around more complicated expressions: inserting ``torch._check(a < 4)`` will typically tell the compiler that ``a >= 4`` is false, and checks on complex expressions like ``torch._check(a ** 2 - 3 * a <= 10)`` will typically get you past identical guards. A short sketch of all three categories follows this list.
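Here is a small sketch of these three categories in one place; the module name, bounds, and expressions are invented for illustration, and exactly what the compiler infers from each check may vary with the symbolic shapes subsystem:

.. code-block:: Python

    import torch
    from torch.export import export

    class CheckDemo(torch.nn.Module):
        def forward(self, x, y):
            u = x.item()
            # 1. equality with a non-data-dependent expression:
            #    u is now known to equal y.shape[0] + 4
            torch._check(u == y.shape[0] + 4)
            # 2. range refinement: u is now known to lie in [0, 100]
            torch._check(u >= 0)
            torch._check(u <= 100)
            # 3. a check on a more complicated expression; an identical
            #    guard encountered later would typically be discharged by it
            torch._check(u ** 2 - 3 * u <= 50)
            return y.sum() + u

    ep = export(CheckDemo(), (torch.tensor(8), torch.randn(4)))

As mentioned previously, ``torch._check()`` calls have applicability outside of data-dependent control flow. For example, here's a model where ``torch._check()`` insertion prevails while manual specialization and ``torch.cond()`` do not:

.. GENERATED FROM PYTHON SOURCE LINES 698-713

..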
code-block:: Python class Foo(torch.nn.Module): def forward(self, x, y): a = x.item() return y[a] inps = ( torch.tensor(32), torch.randn(60), ) try: export(Foo(), inps) except Exception: tb.print_exc() .. rst-class:: sphx-glr-script-out .. code-block:: none I0718 23:09:27.607000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [16/0] create_env I0718 23:09:27.611000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [16/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:701 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.611000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [16/0] compute_unbacked_bindings [u0] V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] Data dependent variable 'u0' allocated at: V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/bin/sphinx-build", line 8, in V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] sys.exit(main()) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 313, in main V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return make_main(argv) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 195, in make_main V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return make_mode.run_make_mode(argv[1:]) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return make.run_generic_build(args[0]) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return build_main(args + opts) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 276, in build_main V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] app = Sphinx(args.sourcedir, args.confdir, args.outputdir, V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 262, in __init__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self._init_builder() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 335, in _init_builder V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self.events.emit('builder-inited') V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] results.append(listener.handler(self.app, *args)) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File 
"/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 743, in generate_gallery_rst V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] ) = generate_dir_rst( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 598, in generate_dir_rst V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] results = parallel( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 599, in V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] p_fun(fname, target_dir, src_dir, gallery_conf) for fname in iterator V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/var/lib/workspace/conf.py", line 79, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] p.start() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/process.py", line 121, in start V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self._popen = self._Popen(self) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/context.py", line 224, in _Popen V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return _default_context.get_context().Process._Popen(process_obj) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/context.py", line 281, in _Popen V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return Popen(process_obj) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 19, in __init__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self._launch(process_obj) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/popen_fork.py", line 71, in _launch V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] code = process_obj._bootstrap(parent_sentinel=child_r) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self.run() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self._target(*self._args, **self._kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/var/lib/workspace/conf.py", line 67, in call_fn V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] result = func(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1346, in generate_file_rst V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] output_blocks, time_elapsed = execute_script( 
V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1164, in execute_script V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] execute_code_block( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1020, in execute_code_block V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] is_last_expr, mem_max = _exec_and_get_memory( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 865, in _exec_and_get_memory V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] mem_max, _ = call_memory( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 1700, in _sg_call_memory_noop V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return 0.0, func() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 783, in __call__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] exec(self.code, self.fake_main.__dict__) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] export(Foo(), inps) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return _export( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] ep = fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] ep = _export_for_training( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] ep = fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] 
File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] export_artifact = export_func( # type: ignore[operator] V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] gm_torch_level = _export_to_torch_ir( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] gm_torch_level, _ = torch._dynamo.export( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] result_traced = opt_f(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return self._call_impl(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return forward_call(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 655, in _fn V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return self._call_impl(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return forward_call(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1432, in __call__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return self._torchdynamo_orig_callable( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 598, in __call__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return _compile( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 1059, in _compile V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] guarded_code = 
compile_inner(code, one_graph, hooks, transform) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_utils_internal.py", line 97, in wrapper_function V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return function(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 761, in compile_inner V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return _compile_inner(code, one_graph, hooks, transform) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 797, in _compile_inner V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] out_code = transform_code_object(code, transform) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/bytecode_transformation.py", line 1422, in transform_code_object V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] transformations(instructions, code_options) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 257, in _fn V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/convert_frame.py", line 715, in transform V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] tracer.run() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 3500, in run V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] super().run() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1337, in run V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] while self.step(): V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1246, in step V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self.dispatch_table[inst.opcode](self, inst) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 819, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return inner_fn(self, inst) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 2168, in CALL_FUNCTION V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self.call_function(fn, args, {}) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/symbolic_convert.py", line 1170, in 
call_function V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type] V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/misc.py", line 903, in call_function V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return self.obj.call_method(tx, self.name, args, kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/tensor.py", line 632, in call_method V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return wrap_fx_proxy( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2302, in wrap_fx_proxy V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2368, in wrap_fx_proxy_cls V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return _wrap_fx_proxy( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/variables/builder.py", line 2464, in _wrap_fx_proxy V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 3127, in get_fake_value V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] ret_val = wrap_fake_exception( V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 2641, in wrap_fake_exception V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 3128, in V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] lambda: run_node(tx.output, node, args, kwargs, nnmodule) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/utils.py", line 3295, in run_node V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return getattr(args[0], node.target)(*args[1:], **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/utils/_stats.py", line 27, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return fn(*args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1282, in __torch_dispatch__ V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return 
self.dispatch(func, types, args, kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1823, in dispatch V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return self._cached_dispatch_impl(func, types, args, kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 1393, in _cached_dispatch_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] output = self._dispatch_impl(func, types, args, kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2397, in _dispatch_impl V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] op_impl_out = op_impl(self, func, *args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 160, in dispatch_to_op_implementations_dict V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return op_implementations_dict[func](fake_mode, func, *args, **kwargs) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_impls.py", line 422, in local_scalar_dense V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] r = fake_mode.shape_env.create_unbacked_symint() V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] return retlog(fn(*args, **kwargs)) V0718 23:09:27.615000 25987 torch/fx/experimental/symbolic_shapes.py:5984] [16/0] W0718 23:09:27.624000 25987 torch/fx/experimental/symbolic_shapes.py:6679] [16/0] failed during evaluate_expr(-u0 > 60, hint=None, size_oblivious=True, forcing_spec=False E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] failed while running evaluate_expr(*(-u0 > 60, None, False, True), **{}) E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] Traceback (most recent call last): E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] return retlog(fn(*args, **kwargs)) E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] return self._evaluate_expr( E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] raise self._make_data_dependent_error( E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] 
torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none) E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select) E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] For more information, run with TORCH_LOGS="dynamic" E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] User Stack (most recent call last): E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] (snipped, see stack below for prefix) E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] return y[a] E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] E0718 23:09:27.624000 25987 torch/fx/experimental/recording.py:299] [16/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] failed while attempting to run meta for aten.select.int E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] Traceback (most recent call last): E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_subclasses/fake_tensor.py", line 2427, in _dispatch_impl E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] r = func(*args, **kwargs) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_ops.py", line 756, in __call__ E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return self._op(*args, **kwargs) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/_meta_registrations.py", line 5278, in meta_select E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] guard_size_oblivious(-index > size) or guard_size_oblivious(index >= size) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 408, in guard_size_oblivious E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return expr.node.guard_size_oblivious("", 0) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 588, in guard_size_oblivious E0718 23:09:27.626000 25987 
torch/_subclasses/fake_tensor.py:2431] [16/0] r = self.evaluate(size_oblivious=True) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return self.shape_env.evaluate_sym_node(self, size_oblivious) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return self.evaluate_expr( E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return retlog(fn(*args, **kwargs)) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return self._evaluate_expr( E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] raise self._make_data_dependent_error( E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). 
(Size-like symbols: none) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] For more information, run with TORCH_LOGS="dynamic" E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] User Stack (most recent call last): E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] (snipped, see stack below for prefix) E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] return y[a] E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] E0718 23:09:27.626000 25987 torch/_subclasses/fake_tensor.py:2431] [16/0] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 Traceback (most recent call last): File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 709, in export(Foo(), inps) File "/usr/local/lib/python3.10/dist-packages/torch/export/__init__.py", line 360, in export return _export( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 2112, in _export ep = _export_for_training( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1092, in wrapper raise e File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1065, in wrapper ep = fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/exported_program.py", line 121, in wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1975, in _export_for_training export_artifact = export_func( # type: ignore[operator] File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1344, in _strict_export_lower_to_aten_ir gm_torch_level = _export_to_torch_ir( File "/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 739, in _export_to_torch_ir gm_torch_level, _ = torch._dynamo.export( File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 1677, in inner result_traced = opt_f(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File 
"/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 659, in _fn raise e.with_traceback(None) from None torch._dynamo.exc.UserError: Could not guard on data-dependent expression -u0 > 60 (unhinted: -u0 > 60). (Size-like symbols: none) Caused by: return y[a] # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:702 in forward (_meta_registrations.py:5278 in meta_select) For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing User Stack (most recent call last): (snipped, see stack below for prefix) File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward return y[a] For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#constrain-as-size-example from user code: File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 702, in forward return y[a] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo" .. GENERATED FROM PYTHON SOURCE LINES 714-718 Here is a scenario where ``torch._check()`` insertion is required simply to prevent an operation from failing. The export call will fail with "Could not guard on data-dependent expression ``-u0 > 60``", implying that the compiler doesn't know if this is a valid indexing operation - if the value of ``x`` is out-of-bounds for ``y`` or not. Here, manual specialization is too prohibitive, and ``torch.cond()`` has no place. Instead, informing the compiler of ``u0``'s range is sufficient: .. GENERATED FROM PYTHON SOURCE LINES 718-733 .. code-block:: Python class Foo(torch.nn.Module): def forward(self, x, y): a = x.item() torch._check(a >= 0) torch._check(a < y.shape[0]) return y[a] inps = ( torch.tensor(32), torch.randn(60), ) ep = export(Foo(), inps) print(ep) .. rst-class:: sphx-glr-script-out .. 
code-block:: none I0718 23:09:27.638000 25987 torch/fx/experimental/symbolic_shapes.py:3334] [17/0] create_env I0718 23:09:27.642000 25987 torch/fx/experimental/symbolic_shapes.py:4276] [17/0] create_unbacked_symint u0 [-int_oo, int_oo] a = x.item() # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.643000 25987 torch/fx/experimental/symbolic_shapes.py:1130] [17/0] compute_unbacked_bindings [u0] I0718 23:09:27.644000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [17/0] runtime_assert u0 >= 0 [guard added] torch._check(a >= 0) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:722 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0" V0718 23:09:27.645000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [17/0] _update_var_to_range u0 = VR[0, int_oo] (update) I0718 23:09:27.648000 25987 torch/fx/experimental/symbolic_shapes.py:6630] [17/0] runtime_assert u0 < 60 [guard added] torch._check(a < y.shape[0]) # ar/lib/workspace/intermediate_source/torch_export_tutorial.py:723 in forward (_dynamo/utils.py:3284 in run_node), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 < 60" V0718 23:09:27.649000 25987 torch/fx/experimental/symbolic_shapes.py:6071] [17/0] _update_var_to_range u0 = VR[0, 59] (update) V0718 23:09:27.652000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [17/0] eval size_oblivious(-u0 > 60) == False [statically known] V0718 23:09:27.652000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [17/0] eval size_oblivious(u0 >= 60) == False [statically known] V0718 23:09:27.653000 25987 torch/fx/experimental/symbolic_shapes.py:6787] [17/0] eval False == True [statically known] V0718 23:09:27.656000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [17/0] runtime_assert u0 >= 0 == True [statically known] V0718 23:09:27.657000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [17/0] runtime_assert u0 <= 59 == True [statically known] V0718 23:09:27.658000 25987 torch/fx/experimental/symbolic_shapes.py:7018] [17/0] runtime_assert u0 < 60 == True [statically known] I0718 23:09:27.661000 25987 torch/fx/experimental/symbolic_shapes.py:4734] [17/0] produce_guards V0718 23:09:27.661000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['x'].storage_offset() 0 None V0718 23:09:27.662000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].size()[0] 60 None V0718 23:09:27.662000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].stride()[0] 1 None V0718 23:09:27.662000 25987 torch/fx/experimental/symbolic_shapes.py:4954] [17/0] track_symint L['y'].storage_offset() 0 None I0718 23:09:27.675000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u1 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.676000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u1] V0718 23:09:27.676000 25987 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u1 = VR[0, 59] (update) I0718 23:09:27.677000 25987 torch/fx/experimental/symbolic_shapes.py:6234] set_replacement u1 = u0 (rename_unbacked_to) VR[0, 59] ExportedProgram: class GraphModule(torch.nn.Module): def forward(self, x: "i64[]", y: "f32[60]"): # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:721 in forward, code: a = x.item() item: "Sym(u0)" 
= torch.ops.aten.item.default(x); x = None
    ge_1: "Sym(u0 >= 0)" = item >= 0
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_1 = _assert_scalar_default = None
    le_1: "Sym(u0 <= 59)" = item <= 59
    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(le_1, "Runtime assertion failed for expression u0 <= 59 on node 'le_1'"); le_1 = _assert_scalar_default_1 = None
    #
    lt_1: "Sym(u0 < 60)" = item < 60
    _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u0 < 60 on node 'lt_1'"); lt_1 = _assert_scalar_default_2 = None

    # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:724 in forward, code: return y[a]
    select: "f32[]" = torch.ops.aten.select.int(y, 0, item); y = item = None
    return (select,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='select'), target=None)])
Range constraints: {u0: VR[0, 59], u1: VR[0, 59]}

.. GENERATED FROM PYTHON SOURCE LINES 734-744

Specialized values
^^^^^^^^^^^^^^^^^^

Another category of data-dependent error happens when the program attempts to extract a concrete data-dependent integer/float value while tracing. This looks something like "Could not extract specialized integer from data-dependent expression", and is analogous to the previous class of errors: data-dependent guard errors arise when evaluating concrete boolean values, while these errors arise when attempting to evaluate concrete integer/float values. This error typically occurs when there is an explicit or implicit ``int()`` cast on a data-dependent expression. For example, this list comprehension has a ``range()`` call that implicitly does an ``int()`` cast on the size of the list:

.. GENERATED FROM PYTHON SOURCE LINES 744-760

.. code-block:: Python

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            # range(a) implicitly casts the data-dependent value a to an int
            b = torch.cat([y for y in range(a)], dim=0)
            return b + int(a)

    inps = (
        torch.tensor(32),
        torch.randn(60),
    )
    try:
        export(Foo(), inps, strict=False)
    except Exception:
        tb.print_exc()

.. rst-class:: sphx-glr-script-out

..
code-block:: none I0718 23:09:27.693000 25987 torch/fx/experimental/symbolic_shapes.py:3334] create_env I0718 23:09:27.698000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense) I0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u0] V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] Data dependent variable 'u0' allocated at: V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/bin/sphinx-build", line 8, in V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] sys.exit(main()) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 313, in main V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] return make_main(argv) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 195, in make_main V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] return make_mode.run_make_mode(argv[1:]) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 160, in run_make_mode V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] return make.run_generic_build(args[0]) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/make_mode.py", line 148, in run_generic_build V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] return build_main(args + opts) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/cmd/build.py", line 276, in build_main V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] app = Sphinx(args.sourcedir, args.confdir, args.outputdir, V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 262, in __init__ V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] self._init_builder() V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/application.py", line 335, in _init_builder V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] self.events.emit('builder-inited') V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx/events.py", line 94, in emit V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] results.append(listener.handler(self.app, *args)) V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_gallery.py", line 743, in generate_gallery_rst V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] ) = generate_dir_rst( V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] File "/usr/local/lib/python3.10/dist-packages/sphinx_gallery/gen_rst.py", line 598, in generate_dir_rst V0718 23:09:27.699000 25987 torch/fx/experimental/symbolic_shapes.py:5984] results = parallel( V0718 23:09:27.699000 25987 

    def forward(self, arg0_1: "i64[]", arg1_1: "f32[60]"):
        # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:747 in forward, code: a = x.item()
        item: "Sym(u0)" = torch.ops.aten.item.default(arg0_1);  arg0_1 = item = None

    Traceback (most recent call last):
      File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 756, in <module>
        export(Foo(), inps, strict=False)
      [... internal torch.export / FX tracing frames omitted ...]
"/usr/local/lib/python3.10/dist-packages/torch/export/_trace.py", line 1824, in forward tree_out = mod(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 813, in module_call_wrapper return self.call_module(mod, forward, args, kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/proxy_tensor.py", line 1808, in call_module return Tracer.call_module(self, m, forward, args, kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 531, in call_module ret_val = forward(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/fx/_symbolic_trace.py", line 806, in forward return _orig_module_call(mod, *args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1762, in _call_impl return forward_call(*args, **kwargs) File "/var/lib/workspace/intermediate_source/torch_export_tutorial.py", line 748, in forward b = torch.cat([y for y in range(a)], dim=0) File "/usr/local/lib/python3.10/dist-packages/torch/__init__.py", line 431, in __index__ return self.node.int_() File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 466, in int_ return self.guard_int("", 0) # NB: uses Python backtrace File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 516, in guard_int r = self.evaluate() File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/sym_node.py", line 510, in evaluate return self.shape_env.evaluate_sym_node(self, size_oblivious) File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6655, in evaluate_sym_node return self.evaluate_expr( File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/recording.py", line 263, in wrapper return retlog(fn(*args, **kwargs)) File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6671, in evaluate_expr return self._evaluate_expr( File "/usr/local/lib/python3.10/dist-packages/torch/fx/experimental/symbolic_shapes.py", line 6894, in _evaluate_expr raise self._make_data_dependent_error( torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode: Could not extract specialized integer from data-dependent expression u0 (unhinted: u0). (Size-like symbols: none) Caused by: (ar/lib/workspace/intermediate_source/torch_export_tutorial.py:748 in forward) For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u0" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing For C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 .. GENERATED FROM PYTHON SOURCE LINES 761-766 For these errors, some basic options you have are: 1. Avoid unnecessary ``int()`` cast calls, in this case the ``int(a)`` in the return statement. 2. Use ``torch._check()`` calls; unfortunately all you may be able to do in this case is specialize (with ``torch._check(a == 60)``). 3. Rewrite the offending code at a higher level. For example, the list comprehension is semantically a ``repeat()`` op, which doesn't involve an ``int()`` cast. The following rewrite avoids data-dependent errors: .. 
The following rewrite (option 3) avoids data-dependent errors:

.. GENERATED FROM PYTHON SOURCE LINES 766-780

.. code-block:: Python

    class Foo(torch.nn.Module):
        def forward(self, x, y):
            a = x.item()
            b = y.unsqueeze(0).repeat(a, 1)
            return b + a

    inps = (
        torch.tensor(32),
        torch.randn(60),
    )
    ep = export(Foo(), inps, strict=False)
    print(ep)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    I0718 23:09:27.719000 25987 torch/fx/experimental/symbolic_shapes.py:3334] create_env
    I0718 23:09:27.724000 25987 torch/fx/experimental/symbolic_shapes.py:4276] create_unbacked_symint u0 [-int_oo, int_oo] (_subclasses/fake_impls.py:422 in local_scalar_dense)
    I0718 23:09:27.725000 25987 torch/fx/experimental/symbolic_shapes.py:1130] compute_unbacked_bindings [u0]
    I0718 23:09:27.729000 25987 torch/fx/experimental/symbolic_shapes.py:6630] runtime_assert u0 >= 0 [guard added] (_refs/__init__.py:4796 in new_empty), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="u0 >= 0"
    V0718 23:09:27.730000 25987 torch/fx/experimental/symbolic_shapes.py:6071] _update_var_to_range u0 = VR[0, int_oo] (update)
    V0718 23:09:27.732000 25987 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Eq(u0, 0)) == False [statically known]
    V0718 23:09:27.735000 25987 torch/fx/experimental/symbolic_shapes.py:6787] eval size_oblivious(Eq(u0, 1)) == False [statically known]
    V0718 23:09:27.736000 25987 torch/fx/experimental/symbolic_shapes.py:7018] runtime_assert True == True [statically known]
    I0718 23:09:27.739000 25987 torch/fx/experimental/symbolic_shapes.py:4734] produce_guards
    V0718 23:09:27.740000 25987 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][0].storage_offset() 0 None
    V0718 23:09:27.740000 25987 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].size()[0] 60 None
    V0718 23:09:27.740000 25987 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].stride()[0] 1 None
    V0718 23:09:27.741000 25987 torch/fx/experimental/symbolic_shapes.py:4954] track_symint L['args'][0][1].storage_offset() 0 None
    V0718 23:09:27.742000 25987 torch/fx/experimental/symbolic_shapes.py:7018] runtime_assert u0 >= 0 == True [statically known]
    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "i64[]", y: "f32[60]"):
                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
                item: "Sym(u0)" = torch.ops.aten.item.default(x);  x = None
                sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item);  sym_constrain_range_for_size_default = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:769 in forward, code: a = x.item()
                ge: "Sym(u0 >= 0)" = item >= 0
                _assert_scalar_default = torch.ops.aten._assert_scalar.default(ge, "Runtime assertion failed for expression u0 >= 0 on node 'ge'");  ge = _assert_scalar_default = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:770 in forward, code: b = y.unsqueeze(0).repeat(a, 1)
                unsqueeze: "f32[1, 60]" = torch.ops.aten.unsqueeze.default(y, 0);  y = None
                repeat: "f32[u0, 60]" = torch.ops.aten.repeat.default(unsqueeze, [item, 1]);  unsqueeze = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:771 in forward, code: return b + a
                add: "f32[u0, 60]" = torch.ops.aten.add.Tensor(repeat, item);  repeat = item = None
                return (add,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None), InputSpec(kind=, arg=TensorArgument(name='y'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='add'), target=None)])
    Range constraints: {u0: VR[0, int_oo]}

.. GENERATED FROM PYTHON SOURCE LINES 781-784

Data-dependent errors can be much more involved, and there are many more options in your toolkit to deal with them: ``torch._check_is_size()``, ``guard_size_oblivious()``, or real-tensor tracing, as starters. For more in-depth guides, please refer to the `Export Programming Model `_, or `Dealing with GuardOnDataDependentSymNode errors `_.
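As a loose illustration of the first two of these tools (our own sketch, not part of the tutorial source; the module, shapes, and values are assumptions), ``torch._check_is_size()`` marks an unbacked value as size-like, and an explicit ``torch._check()`` upper bound can make data-dependent slicing exportable:

.. code-block:: Python

    class TruncateByLength(torch.nn.Module):  # hypothetical module
        def forward(self, x, y):
            n = x.item()
            torch._check_is_size(n)          # tell export that n behaves like a size (n >= 0)
            torch._check(n <= y.shape[0])    # bound n so the slice below is well-formed
            return y[:n]

    # Sketch of usage, reusing the ``export`` import from earlier in the tutorial.
    ep_trunc = export(TruncateByLength(), (torch.tensor(4), torch.randn(60)), strict=False)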
.. GENERATED FROM PYTHON SOURCE LINES 786-796

Custom Ops
----------

``torch.export`` can export PyTorch programs with custom operators. Please refer to `this page `__ on how to author a custom operator in either C++ or Python.

The following is an example of registering a custom operator in Python to be used by ``torch.export``. The important thing to note is that the custom op must have a `FakeTensor kernel `__.

.. GENERATED FROM PYTHON SOURCE LINES 796-807

.. code-block:: Python

    @torch.library.custom_op("my_custom_library::custom_op", mutates_args={})
    def custom_op(x: torch.Tensor) -> torch.Tensor:
        print("custom_op called!")
        return torch.relu(x)

    @custom_op.register_fake
    def custom_op_meta(x):
        # Returns an empty tensor with the same shape as the expected output
        return torch.empty_like(x)
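Once registered, the operator is reachable through the ``torch.ops`` namespace like any built-in op. A quick eager sanity check (our addition, not in the tutorial source) might look like:

.. code-block:: Python

    # The eager call dispatches to the real kernel, so "custom_op called!" is printed.
    out = torch.ops.my_custom_library.custom_op(torch.randn(4))
    assert (out >= 0).all()  # relu output is non-negative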
.. GENERATED FROM PYTHON SOURCE LINES 808-809

Here is an example of exporting a program with the custom op.

.. GENERATED FROM PYTHON SOURCE LINES 809-821

.. code-block:: Python

    class CustomOpExample(torch.nn.Module):
        def forward(self, x):
            x = torch.sin(x)
            x = torch.ops.my_custom_library.custom_op(x)
            x = torch.cos(x)
            return x

    exported_custom_op_example = export(CustomOpExample(), (torch.randn(3, 3),))
    print(exported_custom_op_example)
    print(exported_custom_op_example.module()(torch.randn(3, 3)))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    ExportedProgram:
        class GraphModule(torch.nn.Module):
            def forward(self, x: "f32[3, 3]"):
                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:812 in forward, code: x = torch.sin(x)
                sin: "f32[3, 3]" = torch.ops.aten.sin.default(x);  x = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:813 in forward, code: x = torch.ops.my_custom_library.custom_op(x)
                custom_op: "f32[3, 3]" = torch.ops.my_custom_library.custom_op.default(sin);  sin = None

                # File: /var/lib/workspace/intermediate_source/torch_export_tutorial.py:814 in forward, code: x = torch.cos(x)
                cos: "f32[3, 3]" = torch.ops.aten.cos.default(custom_op);  custom_op = None
                return (cos,)

    Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=, arg=TensorArgument(name='x'), target=None, persistent=None)], output_specs=[OutputSpec(kind=, arg=TensorArgument(name='cos'), target=None)])
    Range constraints: {}

    custom_op called!
    tensor([[1.0000, 0.8897, 1.0000],
            [0.8974, 1.0000, 1.0000],
            [1.0000, 1.0000, 0.9156]])

.. GENERATED FROM PYTHON SOURCE LINES 822-823

Note that in the ``ExportedProgram``, the custom operator is included in the graph.

.. GENERATED FROM PYTHON SOURCE LINES 825-841

IR/Decompositions
-----------------

``torch.export`` produces a graph containing only `ATen operators `__, which are the basic unit of computation in PyTorch. As there are over 3000 ATen operators, export provides a way to narrow down the operator set used in the graph based on certain characteristics, creating different IRs.

By default, export produces the most generic IR which contains all ATen operators, including both functional and non-functional operators. A functional operator is one that does not contain any mutations or aliasing of the inputs. You can find a list of all ATen operators `here `__ and you can inspect whether an operator is functional by checking ``op._schema.is_mutable``, for example:

.. GENERATED FROM PYTHON SOURCE LINES 841-845

.. code-block:: Python

    print(torch.ops.aten.add.Tensor._schema.is_mutable)
    print(torch.ops.aten.add_.Tensor._schema.is_mutable)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    False
    True

.. GENERATED FROM PYTHON SOURCE LINES 846-850

This generic IR can be used to train in eager PyTorch Autograd. It can be reached explicitly through the API ``torch.export.export_for_training``, which was introduced in PyTorch 2.5; as of PyTorch 2.6, calling ``torch.export.export`` produces the same graph.

.. GENERATED FROM PYTHON SOURCE LINES 850-865
.. code-block:: Python

    class DecompExample(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = torch.nn.Conv2d(1, 3, 1, 1)
            self.bn = torch.nn.BatchNorm2d(3)

        def forward(self, x):
            x = self.conv(x)
            x = self.bn(x)
            return (x,)

    ep_for_training = torch.export.export_for_training(DecompExample(), (torch.randn(1, 1, 3, 3),))
    print(ep_for_training.graph)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    graph():
        %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
        %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
        %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
        %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
        %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
        %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
        %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
        %x : [num_users=1] = placeholder[target=x]
        %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
        %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
        %batch_norm : [num_users=1] = call_function[target=torch.ops.aten.batch_norm.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05, True), kwargs = {})
        return (batch_norm,)

.. GENERATED FROM PYTHON SOURCE LINES 866-873

We can then lower this exported program to an operator set which only contains functional ATen operators through the API ``run_decompositions``, which decomposes the ATen operators into the ones specified in the decomposition table, and functionalizes the graph. By specifying an empty decomposition table, we only perform functionalization, without any additional decompositions. This results in an IR which contains ~2000 operators (instead of the 3000 operators above), and is ideal for inference cases.

.. GENERATED FROM PYTHON SOURCE LINES 873-877

.. code-block:: Python

    ep_for_inference = ep_for_training.run_decompositions(decomp_table={})
    print(ep_for_inference.graph)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    graph():
        %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
        %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
        %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
        %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
        %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
        %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
        %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
        %x : [num_users=1] = placeholder[target=x]
        %conv2d : [num_users=1] = call_function[target=torch.ops.aten.conv2d.default](args = (%x, %p_conv_weight, %p_conv_bias), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
        %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%conv2d, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
        %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
        %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
        return (getitem_3, getitem_4, add, getitem)

.. GENERATED FROM PYTHON SOURCE LINES 878-881

As we can see, the previously mutable operator ``torch.ops.aten.add_.Tensor`` has now been replaced with ``torch.ops.aten.add.Tensor``, a functional operator.

.. GENERATED FROM PYTHON SOURCE LINES 883-888

We can also further lower this exported program to an operator set which only contains the `Core ATen Operator Set `__, which is a collection of only ~180 operators. This IR is optimal for backends who do not want to reimplement all ATen operators.

.. GENERATED FROM PYTHON SOURCE LINES 888-895

.. code-block:: Python

    from torch.export import default_decompositions

    core_aten_decomp_table = default_decompositions()
    core_aten_ep = ep_for_training.run_decompositions(decomp_table=core_aten_decomp_table)
    print(core_aten_ep.graph)

.. rst-class:: sphx-glr-script-out
.. code-block:: none

    graph():
        %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
        %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
        %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
        %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
        %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
        %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
        %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
        %x : [num_users=1] = placeholder[target=x]
        %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
        %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%convolution, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
        %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
        %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
        return (getitem_3, getitem_4, add, getitem)

.. GENERATED FROM PYTHON SOURCE LINES 896-900

We now see that ``torch.ops.aten.conv2d.default`` has been decomposed into ``torch.ops.aten.convolution.default``. This is because ``convolution`` is a more "core" operator, as operations like ``conv1d`` and ``conv2d`` can be implemented using the same op.
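To verify what a lowered program actually contains, you can walk the FX graph yourself. The following small helper is our own sketch, not part of the tutorial source; it collects the operator targets called in the lowered graph:

.. code-block:: Python

    # Collect the operator overloads actually called in the lowered graph.
    used_ops = {
        node.target
        for node in core_aten_ep.graph.nodes
        if node.op == "call_function"
    }
    print(used_ops)  # e.g. includes torch.ops.aten.convolution.default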
.. GENERATED FROM PYTHON SOURCE LINES 902-903

We can also specify our own decomposition behaviors:

.. GENERATED FROM PYTHON SOURCE LINES 903-913

.. code-block:: Python

    my_decomp_table = torch.export.default_decompositions()

    def my_awesome_custom_conv2d_function(x, weight, bias, stride=[1, 1], padding=[0, 0], dilation=[1, 1], groups=1):
        return 2 * torch.ops.aten.convolution(x, weight, bias, stride, padding, dilation, False, [0, 0], groups)

    my_decomp_table[torch.ops.aten.conv2d.default] = my_awesome_custom_conv2d_function
    my_ep = ep_for_training.run_decompositions(my_decomp_table)
    print(my_ep.graph)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    graph():
        %p_conv_weight : [num_users=1] = placeholder[target=p_conv_weight]
        %p_conv_bias : [num_users=1] = placeholder[target=p_conv_bias]
        %p_bn_weight : [num_users=1] = placeholder[target=p_bn_weight]
        %p_bn_bias : [num_users=1] = placeholder[target=p_bn_bias]
        %b_bn_running_mean : [num_users=1] = placeholder[target=b_bn_running_mean]
        %b_bn_running_var : [num_users=1] = placeholder[target=b_bn_running_var]
        %b_bn_num_batches_tracked : [num_users=1] = placeholder[target=b_bn_num_batches_tracked]
        %x : [num_users=1] = placeholder[target=x]
        %convolution : [num_users=1] = call_function[target=torch.ops.aten.convolution.default](args = (%x, %p_conv_weight, %p_conv_bias, [1, 1], [0, 0], [1, 1], False, [0, 0], 1), kwargs = {})
        %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%convolution, 2), kwargs = {})
        %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_bn_num_batches_tracked, 1), kwargs = {})
        %_native_batch_norm_legit_functional : [num_users=3] = call_function[target=torch.ops.aten._native_batch_norm_legit_functional.default](args = (%mul, %p_bn_weight, %p_bn_bias, %b_bn_running_mean, %b_bn_running_var, True, 0.1, 1e-05), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 0), kwargs = {})
        %getitem_3 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 3), kwargs = {})
        %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%_native_batch_norm_legit_functional, 4), kwargs = {})
        return (getitem_3, getitem_4, add, getitem)

.. GENERATED FROM PYTHON SOURCE LINES 914-918

Notice that instead of ``torch.ops.aten.conv2d.default`` being decomposed into ``torch.ops.aten.convolution.default``, it is now decomposed into ``torch.ops.aten.convolution.default`` and ``torch.ops.aten.mul.Tensor``, which matches our custom decomposition rule.

.. GENERATED FROM PYTHON SOURCE LINES 920-935

ExportDB
--------

``torch.export`` will only ever export a single computation graph from a PyTorch program. Because of this requirement, there will be Python or PyTorch features that are not compatible with ``torch.export``, which will require users to rewrite parts of their model code. We have seen examples of this earlier in the tutorial -- for example, rewriting if-statements using ``cond``.

`ExportDB `__ is the standard reference that documents supported and unsupported Python/PyTorch features for ``torch.export``. It is essentially a list of program samples, each of which represents the usage of one particular Python/PyTorch feature and its interaction with ``torch.export``. Examples are also tagged by category so that they can be more easily searched.

For example, let's use ExportDB to get a better understanding of how the predicate works in the ``cond`` operator. We can look at the example called ``cond_predicate``, which has a ``torch.cond`` tag. The example code looks like:

.. GENERATED FROM PYTHON SOURCE LINES 935-946

.. code-block:: Python

    def cond_predicate(x):
        """
        The conditional statement (aka predicate) passed to ``cond()`` must be one of the following:
          - ``torch.Tensor`` with a single element
          - boolean expression

        NOTE: If the ``pred`` is a test on a dim with batch size < 2, it will be specialized.
        """
        pred = x.dim() > 2 and x.shape[2] > 10
        return cond(pred, lambda x: x.cos(), lambda y: y.sin(), [x])
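To make the predicate rules concrete, here is a small self-contained sketch of exporting a module that uses ``torch.cond`` (the module and inputs are our own illustration, not taken from ExportDB):

.. code-block:: Python

    class CondExample(torch.nn.Module):
        def forward(self, x):
            # The predicate is a single-element tensor; both branches take the
            # same operands and must return outputs of the same shape and dtype.
            return torch.cond(x.sum() > 0, lambda x: x.cos(), lambda x: x.sin(), [x])

    ep_cond = export(CondExample(), (torch.randn(3, 3),))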
.. GENERATED FROM PYTHON SOURCE LINES 947-955

More generally, ExportDB can be used as a reference when one of the following occurs:

1. Before attempting ``torch.export``, you know ahead of time that your model uses some tricky Python/PyTorch features and you want to know if ``torch.export`` covers that feature.
2. When attempting ``torch.export``, there is a failure and it's unclear how to work around it.

ExportDB is not exhaustive, but is intended to cover all use cases found in typical PyTorch code. Feel free to reach out if there is an important Python/PyTorch feature that should be added to ExportDB or supported by ``torch.export``.

.. GENERATED FROM PYTHON SOURCE LINES 957-966

Running the Exported Program
----------------------------

As ``torch.export`` is only a graph capturing mechanism, calling the artifact produced by ``torch.export`` eagerly will be equivalent to running the eager module. To optimize the execution of the exported program, we can pass this exported artifact to backends such as Inductor through ``torch.compile``, `AOTInductor `__, or `TensorRT `__.

.. GENERATED FROM PYTHON SOURCE LINES 966-988

.. code-block:: Python

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(3, 3)

        def forward(self, x):
            x = self.linear(x)
            return x

    inp = torch.randn(2, 3, device="cuda")
    m = M().to(device="cuda")
    ep = torch.export.export(m, (inp,))

    # Run it eagerly
    res = ep.module()(inp)
    print(res)

    # Run it with torch.compile
    res = torch.compile(ep.module(), backend="inductor")(inp)
    print(res)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor([[ 0.8205,  0.2483,  0.3667],
            [ 1.1952, -1.2561,  1.0813]], device='cuda:0', grad_fn=)
    /usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:236: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
    tensor([[ 0.8205,  0.2483,  0.3667],
            [ 1.1952, -1.2561,  1.0813]], device='cuda:0', grad_fn=)
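An ``ExportedProgram`` can also be serialized to disk and reloaded later via ``torch.export.save`` and ``torch.export.load``. A brief sketch (the file name here is our own choice):

.. code-block:: Python

    # Save the ExportedProgram to a .pt2 archive and load it back.
    torch.export.save(ep, "exported_program.pt2")
    loaded_ep = torch.export.load("exported_program.pt2")
    print(loaded_ep.module()(inp))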
.. GENERATED FROM PYTHON SOURCE LINES 989-1003

.. code-block:: python

    import torch._inductor

    # Note: these APIs are subject to change
    # Compile the exported program to a PT2 archive using ``AOTInductor``
    with torch.no_grad():
        pt2_path = torch._inductor.aoti_compile_and_package(ep)

    # Load and run the .so file in Python.
    # To load and run it in a C++ environment, see:
    # https://pytorch.org/docs/main/torch.compiler_aot_inductor.html
    aoti_compiled = torch._inductor.aoti_load_package(pt2_path)
    res = aoti_compiled(inp)

.. GENERATED FROM PYTHON SOURCE LINES 1005-1011

Conclusion
----------

We introduced ``torch.export``, the new PyTorch 2.X way to export single computation graphs from PyTorch programs. In particular, we demonstrated several code modifications and considerations (control flow ops, constraints, etc.) that need to be made in order to export a graph.

.. rst-class:: sphx-glr-timing

**Total running time of the script:** (0 minutes 4.580 seconds)

.. _sphx_glr_download_intermediate_torch_export_tutorial.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: torch_export_tutorial.ipynb <torch_export_tutorial.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: torch_export_tutorial.py <torch_export_tutorial.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: torch_export_tutorial.zip <torch_export_tutorial.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_