```{code-cell}
:tags: [remove-cell]
import torch
import header_code
torch._logging.set_logs(graph_breaks=True)
```
# Toggling `error_on_graph_break`
**Summary:**
- When `fullgraph=False`, we can use `torch._dynamo.error_on_graph_break()` for more flexibility in
dealing with graph breaks.
So far, we have introduced two ways in dealing with graph breaks in `torch.compile`:
1. `fullgraph=True` errors on the first graph break and additionally guarantees that only one graph is traced from the code.
2. `fullgraph=False` continues tracing even when encountering graph breaks.
What if we want to disallow graph breaks for most of the code, but there are a few problematic functions where the graph breaks are hard to remove,
and we are okay with having those graph breaks? We can use `torch._dynamo.error_on_graph_break()` to achieve this.
`torch.compile` has an `error_on_graph_break` setting (initially set to `False`).
If a graph break or compiler error occurs in code while `error_on_graph_break` is set to `False`, then `torch.compile` will attempt to continue compilation after the graph break/error.
If `error_on_graph_break` is set to `True`, then `torch.compile` will abort compilation and propagate the error to user code.
A significant difference between `error_on_graph_break=True` and `fullgraph=True` is that the former **does not guarantee that a single graph will be captured**.
`error_on_graph_break` **can be arbitrarily toggled during compile time** by using the `torch._dynamo.error_on_graph_break()` context manager/decorator.
In comparison, once `fullgraph` is set to `True`, it cannot be set back to `False`.
Finally, `error_on_graph_break` has lower precedence than `fullgraph` - `error_on_graph_break` only takes effect when `fullgraph=False`.
## `error_on_graph_break(False)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def code_with_a_difficult_graph_break(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner(x):
return code_with_a_difficult_graph_break(x)
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
return inner(x)
# No error, but there is a graph break
fn(torch.randn(3))
```
Using `error_on_graph_break(False)` under `error_on_graph_break(True)` is helpful for when we want to minimize graph breaks (i.e. follow the `fullgraph=True` programming model),
but there are some sections of code with non-performance-critical graph breaks that are difficult to work around.
`error_on_graph_break()` can be used as a context manager as well:
```{code-cell}
# NOTE: fullgraph=False
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break() # no error
return x + 2
# No error, but there is a graph break
fn(torch.randn(3))
```
You can use monkey patching to toggle `error_on_graph_break` for code where you cannot edit the source (e.g. framework code):
```{code-cell}
class ThirdPartyModule(torch.nn.Module):
def forward(self, x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
tp_mod = ThirdPartyModule()
tp_mod.forward = torch._dynamo.error_on_graph_break(False)(tp_mod.forward)
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
return tp_mod.forward(x)
# No error, but there is a graph break
fn(torch.randn(3))
```
## `error_on_graph_break(True)` example
```{code-cell}
@torch._dynamo.error_on_graph_break(True)
def inner2(x):
x = x + 1
torch._dynamo.graph_break() # error
return x + 2
def inner(x):
return inner2(x)
# fullgraph=False, error_on_graph_break=False
@torch.compile
def fn(x):
x = x + 4
torch._dynamo.graph_break() # no error
return inner(x)
try:
fn(torch.randn(3))
except Exception as e:
print(e)
```
Using `error_on_graph_break(True)` under `error_on_graph_break(False)` is helpful for when we want to use `torch.compile` flexibly (i.e. follow the `fullgraph=False` programming model),
but there are some sections of the code that are performance-critical and we want to ensure that those sections do not contain graph breaks.
## `error_on_graph_break` nesting behavior
`torch._dynamo.error_on_graph_break()` affects the `error_on_graph_break` setting of nested calls as well:
```{code-cell}
def inner(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
def inner2(x):
with torch._dynamo.error_on_graph_break(False):
return inner(x)
@torch._dynamo.error_on_graph_break(True)
@torch.compile
def fn(x):
return inner2(x)
# no error
fn(torch.randn(3))
```
`torch._dynamo.error_on_graph_break()` can be used under another `torch._dynamo.error_on_graph_break()` region:
```{code-cell}
def inner(x):
x = x + 1
with torch._dynamo.error_on_graph_break(False):
torch._dynamo.graph_break()
return x + 2
def inner2(x):
with torch._dynamo.error_on_graph_break(True):
return inner(x)
@torch.compile
def fn(x):
return inner2(x)
# no error
fn(torch.randn(3))
```
## Interaction with `fullgraph`
`fullgraph=True` takes higher precedence than `error_on_graph_break`:
```{code-cell}
@torch._dynamo.error_on_graph_break(False)
def inner(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(fullgraph=True)
def fn(x):
return inner(x)
try:
fn(torch.randn(3))
except Exception as e:
print(e)
```
`fullgraph=True` cannot be toggled back to `fullgraph=False`:
```{code-cell}
@torch.compile(fullgraph=False)
def inner(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(fullgraph=True)
def fn(x):
return inner(x)
try:
fn(torch.randn(3))
except Exception as e:
print(e)
```
```{code-cell}
@torch.compile(fullgraph=True)
def inner(x):
x = x + 1
torch._dynamo.graph_break()
return x + 2
@torch.compile(fullgraph=False)
def fn(x):
return inner(x)
try:
fn(torch.randn(3))
except Exception as e:
print(e)
```
## Summary of `fullgraph=True/False` vs `error_on_graph_break`
Here is a table summarizing the differences between `fullgraph=True/False` and `error_on_graph_break`:
| | `error_on_graph_break=True` | `error_on_graph_break=False` (default) |
| --- | --- | --- |
| `fullgraph=True` | Graph breaks result in errors. Only the first graph break will be reported. **One graph guarantee.**
`fullgraph` cannot be toggled to `False`. `error_on_graph_break` has no effect.
User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).
Ideal for code sensitive to graph breaks: framework/library code or cases where getting maximum performance is required. Prevents downstream user code from inadvertently allowing graph breaks. | Same as `fullgraph=True` and `error_on_graph_break=True` as `error_on_graph_break` has no effect when `fullgraph=True`. |
| `fullgraph=False` (default) | Graph breaks result in errors. Only the first graph break will be reported. **No one graph guarantee.**
`error_on_graph_break` can be toggled to `False`.
User code must be fully compatible with `torch.compile`. Guarantees no performance hits from graph breaks (because there are no graph breaks).
Ideal for user code sensitive to graph breaks. `error_on_graph_break` can be toggled to `False` to deal with sections that have graph breaks that are difficult to work around. | Will continue to compile after encountering graph breaks. All graph breaks will be reported.
`error_on_graph_break` can be toggled to `True`.
Doesn’t require many user code changes to work. Performance may be negatively impacted due to graph breaks.
Ideal for out-of-the-box use cases, on “non-weird” code, or where squeezing maximal performance is not necessary |