torch.compile has different autograd semantics#
Created On: Jun 26, 2025 | Last Updated On: Jun 02, 2026
When you apply torch.compile to a function in your model’s forward pass,
it will automatically generate a backward pass for the compiled function.
During compilation, it will trace out a graph for the backward pass that
is used whenever autograd is invoked. We refer to the component inside
torch.compile that is responsible for this as AOTDispatcher
(sometimes known as AOTAutograd).
As so, torch.compile bakes in details of the computation into the
traced-out backward graph during compilation of the function
in the forward pass.
However, in eager-mode PyTorch, the backward computation is dynamic:
outside of the forward pass, you can wrap the call to
tensor.backward() or torch.autograd.grad(...)
in a context manager that may change its behavior.
This page documents how torch.compile’s autograd semantics differ from
eager-mode PyTorch and how to work around it.
Autocast behavior#
torch.compile bakes in an assumption about whether the backward pass will be
run under an ambient autocast context manager. Use
torch._functorch.config.backward_pass_autocast to control that assumption;
an incorrect assumption may lead to silent incorrectness.
Warning
AMP recommends that torch.autocast wrap only the forward pass and loss
computation. Backward passes under autocast are not recommended; see the
torch.autocast guidance in Autocasting. The default compiler
setting, "same_as_forward", intentionally preserves existing
torch.compile behavior by assuming the compiled backward runs under the
same autocast context as the compiled forward. If your code follows the AMP
recommendation and runs backward outside autocast, set
torch._functorch.config.backward_pass_autocast to "off" for the
compiled region.
The options are either:
"same_as_forward"(default). We assume that the backward of thetorch.compile’ed region will be run under the same autocast context manager that the region was run under (if any). Use this if your code looks like the following:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # backward pass run under the same autocast context as the compiled region z.backward()
"off". We assume that the backward of the torch.compile’d region will not be run under any autocast context managers. Use this if your code looks like the following:with torch.amp.autocast(...): y = torch.compile(region)(x) ... # Backward pass runs under no autocast. z.backward()
There is a third option. If you set
torch._functorch.config.backward_pass_autocastto a list of kwargs, we will assume the backward pass runs under an autocast context constructed by those kwargs.For example, if your code looks like the following:
y = torch.compile(region)(x) ... # Backward pass runs under special context manager with torch.amp.autocast(**kwargs): z.backward()
then set
torch._functorch.config.backward_pass_autocast = kwargs.
Use patch to apply the option to a specific torch.compile call:
with torch.amp.autocast(...):
with torch._functorch.config.patch(backward_pass_autocast="same_as_forward")
y = torch.compile(region)(x)
...
# backward pass run under the same autocast context as the compiled region
z.backward()