
Nested Graph Breaks

Summary:

  • Graph breaks in nested functions can result in hard-to-understand compiler behavior, which we document below

  • A nested graph break results in O(N) duplicate graph breaks, where N is the nesting depth

Recall that when torch.compile is applied to a function, any nested function calls are also traced. A nested graph break refers to any graph break that happens in a nested function call.

import torch

def inner(x):
    ...
    torch._dynamo.graph_break()  # nested graph break
    ...

@torch.compile
def outer(x):
    ...
    y = inner(x)
    ...

The resumption semantics around nested graph breaks can be confusing, so we describe the behavior here.

Recall that with fullgraph=False, a graph break is handled by compiling the FX graph determined so far, running the unsupported code in regular Python, and then resuming tracing after the unsupported code with a new FX graph. Resuming a function mid-execution is a fairly complicated technical feat, so resuming tracing is only supported for top-level functions.
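
For comparison, here is a minimal sketch of this behavior for a graph break directly in a top-level function, written in the same "semantics" notation used in the examples below (g, compiled_g_semantics, and resume_g_semantics are illustrative names, not real Dynamo internals):

import torch

@torch.compile
def g(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

# Roughly equivalent semantics: compile the graph traced so far, run the
# unsupported code in regular Python, then resume tracing in a new graph.
def compiled_g_semantics(x):
    y = x + 1                    # first compiled FX graph
    torch._dynamo.graph_break()  # unsupported code runs in regular Python
    return torch.compile(resume_g_semantics)(y)

def resume_g_semantics(x):
    return x + 2                 # second compiled FX graph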

Given this restriction, we can resume tracing after a nested graph break in the following way:

First, consider the example below, where torch.compile starts tracing from f and traces until the graph break in inner1 is encountered.

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def inner2(x):
    x = x + 4
    x = inner1(x)
    x = x + 8
    return x

@torch.compile
def f(x):
    # start tracing from here
    x = x + 16
    x = inner2(x)
    x = x + 32
    return x

f(torch.randn(3))

Since we can only resume from top-level functions, we graph break on the inner2 call in f.

# The semantics of torch.compile(f)(x) is roughly this:
def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

inner2 is then automatically compiled as a top-level function. We trace all the way until the graph break in inner1 is encountered again.

def inner1(x):
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

# this torch.compile is automatically applied
@torch.compile
def inner2(x):
    # start tracing from here
    x = x + 4
    x = inner1(x)
    x = x + 8
    return x

def compiled_f_semantics(x):
    y = x + 16
    z = inner2(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

compiled_f_semantics(torch.randn(3))

Then we graph break on the inner1 call in inner2.

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

inner1 is then automatically compiled as a top-level function. This time the graph break occurs directly in the function being traced, so it is handled normally.

# this torch.compile is automatically applied
@torch.compile
def inner1(x):
    # start tracing from here
    x = x + 1
    torch._dynamo.graph_break()  # stop tracing due to graph break
    return x + 2

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = inner1(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

compiled_f_semantics(torch.randn(3))

inner1 is handled normally:

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

So the initial code is semantically equivalent to:

def compiled_f_semantics(x):
    y = x + 16
    z = compiled_inner2_semantics(y)
    return torch.compile(resume_f_semantics)(z)

def resume_f_semantics(x):
    return x + 32

def compiled_inner2_semantics(x):
    y = x + 4
    z = compiled_inner1_semantics(y)
    return torch.compile(resume_inner2_semantics)(z)

def resume_inner2_semantics(x):
    return x + 8

def compiled_inner1_semantics(x):
    y = x + 1
    torch._dynamo.graph_break()
    return torch.compile(resume_inner1_semantics)(y)

def resume_inner1_semantics(x):
    return x + 2

compiled_f_semantics(torch.randn(3))
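
As a quick sanity check (a minimal sketch, assuming the definitions above with their explicit return statements), both the compiled function and its semantic expansion should match the eager result:

import torch

torch._dynamo.reset()  # discard previously compiled code
x = torch.randn(3)
expected = x + 63  # eager f adds 16 + 4 + 1 + 2 + 8 + 32 = 63 in total
assert torch.allclose(f(x), expected)
assert torch.allclose(compiled_f_semantics(x), expected)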

Note in particular that we traced 3 top-level functions (f, inner2, and inner1), and that we traced the same graph break 3 times. This explains why you may encounter duplicate graph breaks when using torch.compile.
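
You can observe these duplicates yourself by enabling graph break logging (a sketch; the exact log output varies across PyTorch versions):

# Run with graph break logging enabled, e.g.:
#   TORCH_LOGS="graph_breaks" python example.py
# The log should report the torch._dynamo.graph_break() call in inner1
# three times: once while tracing f, once while tracing inner2, and once
# while tracing inner1 itself.
import torch

torch._dynamo.reset()
f(torch.randn(3))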

In summary, nested graph breaks are handled by:

  • Tracing from the top-level function all the way to the nested graph break

  • Graph breaking on the top-level function at the call to the second-level function

  • Compiling the PyTorch ops tracked so far and running the compiled graph

  • Calling the second-level function, which gets automatically compiled as a top-level function (the same steps apply recursively if the graph break is nested more deeply)

  • Resuming tracing after the second-level function call

Note that the runtime of handling this graph break is O(NK), where N is the nesting depth and K is the number of instructions from the top-level function to the graph break. We end up tracing O(N^2) frames, and we trace the same graph break O(N) times.
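
To make the scaling concrete, the sketch below (make_chain is an illustrative helper, not a Dynamo API) builds a call chain of depth N with a graph break at the bottom; compiling it traces roughly N + (N-1) + ... + 1 frames:

import torch

def make_chain(depth):
    # Build `depth` nested functions with a graph break at the innermost one.
    def bottom(x):
        x = x + 1
        torch._dynamo.graph_break()  # re-traced once per enclosing frame: O(depth) times
        return x + 1

    fn = bottom
    for _ in range(depth - 1):
        def wrapper(x, inner=fn):  # default arg binds the current innermost chain
            return inner(x + 1) + 1
        fn = wrapper
    return fn

compiled = torch.compile(make_chain(5))
compiled(torch.randn(3))  # Dynamo traces O(depth^2) frames to handle the nested break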