Use torch._dynamo.nonstrict_trace

Created On: Jul 28, 2025 | Last Updated On: Jul 28, 2025

Summary:

  • Use nonstrict_trace to trace a function with non-strict tracing inside a torch.compile’d region. You may want to do this if Dynamo graph breaks on something inside the function and you are sure that the function is non-strict traceable.

Consider the following scenario:

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
try:
    func(torch.rand(10))
except Exception as e:
    print(e)
Call to `torch._dynamo.graph_break()`
  Explanation: User-inserted graph break. Message: None
  Hint: Remove the `torch._dynamo.graph_break()` call.

  Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`

 For more details about this graph break, please visit: https://pytorch-labs.github.io/compile-graph-break-site/gb/gb0025.html

from user code:
   File "/tmp/ipykernel_854/2253748958.py", line 9, in func
    n = get_magic_num()
  File "/tmp/ipykernel_854/2253748958.py", line 5, in get_magic_num
    torch._dynamo.graph_break()

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"

If we run the code above, Dynamo raises an error, because it encounters a graph break even though the user specified fullgraph=True.

In these situations, if a user still wants to keep fullgraph=True, they typically have several options:

  1. The graph break is due to a language feature Dynamo doesn’t yet support. In this case, the user either rewrites their code, or files an issue on GitHub.

  2. The graph break is due to a call to a function implemented in C. In this case, the user can try to use a custom op. The user could also try providing a polyfill (a reference implementation in Python) so that Dynamo can trace through it.

  3. Worst-case scenario: an internal compiler error. In this case, the user likely has to file an issue on GitHub.

In addition to all these options, PyTorch provides an alternative, torch._dynamo.nonstrict_trace, which can be used if the function call that induced the graph break satisfies the following requirements:

  • The function satisfies the requirements of general non-strict tracing.

  • The inputs and outputs must contain either basic types (e.g., int, float, list, dict, torch.Tensor), or user-defined types that are registered to torch.utils._pytree.

  • The function must be defined outside the torch.compile’d region.

  • Any non-input values the function reads (e.g., a global tensor) are treated as constants and are not guarded on.

When tracing through a call to a torch._dynamo.nonstrict_trace’d function, torch.compile switches to non-strict tracing, and the FX graph will eventually contain all the relevant tensor operations that happened inside that function.

For the example above, we can use torch._dynamo.nonstrict_trace to eliminate the graph break:

import torch

@torch._dynamo.nonstrict_trace
def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = get_magic_num()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.6713, 42.0628, 42.8153, 42.0862, 42.4542, 42.7024, 42.0485, 42.7828,
        42.5870, 42.4988])

Note that torch._dynamo.nonstrict_trace can also be applied inside a torch.compile’d region, as long as the function being wrapped is defined outside it:

import torch

def get_magic_num():
    # This explicit graph break call is meant to emulate any kind of Dynamo
    # graph break, e.g., the function is implemented in C, or uses some python
    # language feature Dynamo doesn't yet support.
    torch._dynamo.graph_break()
    return torch.tensor([42])
@torch.compile(fullgraph=True)
def func(x):
    n = torch._dynamo.nonstrict_trace(get_magic_num)()
    return x + n
print(func(torch.rand(10)))
# No graph break and no error.
tensor([42.9439, 42.4081, 42.4315, 42.4321, 42.8948, 42.0755, 42.7228, 42.7411,
        42.5308, 42.0496])