Introduction to torch.compile
Created On: Mar 15, 2023 | Last Updated: Oct 15, 2025 | Last Verified: Nov 05, 2024
Author: William Wen
torch.compile is the new way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling it into optimized kernels, all while requiring minimal code changes.

torch.compile accomplishes this by tracing through your Python code, looking for PyTorch operations. Code that is difficult to trace results in a graph break, which is a lost optimization opportunity rather than an error or silent incorrectness.

torch.compile is available in PyTorch 2.0 and later.

This introduction covers basic torch.compile usage and demonstrates the advantages of torch.compile over our previous PyTorch compiler solution, TorchScript.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial. To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.
Required pip dependencies for this tutorial:

- torch >= 2.0
- numpy
- scipy

System requirements:

- A C++ compiler, such as g++
- Python development package (python-devel/python-dev)
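If you want to sanity-check these requirements from Python, here is a minimal sketch (the GPU check only matters for the speedup demonstration later in this tutorial):

import torch

# torch.compile needs PyTorch 2.0 or later; the speedup demo later in this
# tutorial additionally assumes a CUDA-capable GPU is available.
print(torch.__version__)
print(torch.cuda.is_available())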
Basic Usage

We turn on some logging to help us see what torch.compile is doing under the hood in this tutorial. The following code will print out the PyTorch ops that torch.compile traced.
import torch
torch._logging.set_logs(graph_code=True)
torch.compile takes an arbitrary Python function and returns an optimized version of it; it can also be applied as a decorator.
def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(3, 3), torch.randn(3, 3)))

@torch.compile
def opt_foo2(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

print(opt_foo2(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_1_57703c6c_17e9_44be_adf9_87ae8a7f015f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:74 in foo, code: a = torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:75 in foo, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:76 in foo, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 0.0663, 1.8726, 1.0057],
[-0.3487, 0.3188, 0.9310],
[ 1.8560, 0.4513, -0.4614]])
TRACED GRAPH
===== __compiled_fn_3_12712180_e493_4bc2_8b8e_dcdfd783faaa =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:85 in opt_foo2, code: a = torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:86 in opt_foo2, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:87 in opt_foo2, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 0.2038, 0.5530, 0.2229],
[-0.3382, 0.5160, -0.0161],
[ 1.7310, 1.3559, 1.2261]])
torch.compile is applied recursively, so nested function calls within the top-level compiled function will also be compiled.
def inner(x):
    return torch.sin(x)

@torch.compile
def outer(x, y):
    a = inner(x)
    b = torch.cos(y)
    return a + b

print(outer(torch.randn(3, 3), torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_5_03c189a8_83d7_41cc_a42b_e8e8d534d682 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[3, 3][3, 1]cpu", L_y_: "f32[3, 3][3, 1]cpu"):
l_x_ = L_x_
l_y_ = L_y_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:98 in inner, code: return torch.sin(x)
a: "f32[3, 3][3, 1]cpu" = torch.sin(l_x_); l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:104 in outer, code: b = torch.cos(y)
b: "f32[3, 3][3, 1]cpu" = torch.cos(l_y_); l_y_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:105 in outer, code: return a + b
add: "f32[3, 3][3, 1]cpu" = a + b; a = b = None
return (add,)
tensor([[ 1.2845, -0.0892, -0.2115],
[ 1.3537, -0.0816, -0.0732],
[-0.3591, 1.5748, 0.7948]])
We can also optimize torch.nn.Module instances, either by calling the module's .compile() method or by torch.compile-ing the module directly. This is equivalent to torch.compile-ing the module's __call__ method (which indirectly calls forward).
t = torch.randn(10, 100)

class MyModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(3, 3)

    def forward(self, x):
        return torch.nn.functional.relu(self.lin(x))

mod1 = MyModule()
mod1.compile()
print(mod1(torch.randn(3, 3)))

mod2 = MyModule()
mod2 = torch.compile(mod2)
print(mod2(torch.randn(3, 3)))
TRACED GRAPH
===== __compiled_fn_7_d919aa2b_ce68_443d_ab75_c1f3ad8968a4 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_self_modules_lin_parameters_weight_: "f32[3, 3][3, 1]cpu", L_self_modules_lin_parameters_bias_: "f32[3][1]cpu", L_x_: "f32[3, 3][3, 1]cpu"):
l_self_modules_lin_parameters_weight_ = L_self_modules_lin_parameters_weight_
l_self_modules_lin_parameters_bias_ = L_self_modules_lin_parameters_bias_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:126 in forward, code: return torch.nn.functional.relu(self.lin(x))
linear: "f32[3, 3][3, 1]cpu" = torch._C._nn.linear(l_x_, l_self_modules_lin_parameters_weight_, l_self_modules_lin_parameters_bias_); l_x_ = l_self_modules_lin_parameters_weight_ = l_self_modules_lin_parameters_bias_ = None
relu: "f32[3, 3][3, 1]cpu" = torch.nn.functional.relu(linear); linear = None
return (relu,)
tensor([[0.4863, 0.2575, 0.5411],
[0.1428, 0.0000, 0.3762],
[0.4444, 0.5583, 0.7902]], grad_fn=<CompiledFunctionBackward>)
tensor([[0.0000, 0.0000, 1.4330],
[0.0000, 0.0000, 0.0536],
[0.0000, 0.0000, 0.1456]], grad_fn=<CompiledFunctionBackward>)
Demonstrating Speedups

Now let's demonstrate how torch.compile speeds up a simple PyTorch example. For a demonstration on a more complex model, see our end-to-end torch.compile tutorial.
def foo3(x):
    y = x + 1
    z = torch.nn.functional.relu(y)
    u = z * 2
    return u

opt_foo3 = torch.compile(foo3)
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    # elapsed_time returns milliseconds; convert to seconds
    return result, start.elapsed_time(end) / 1000

inp = torch.randn(4096, 4096).cuda()
print("compile:", timed(lambda: opt_foo3(inp))[1])
print("eager:", timed(lambda: foo3(inp))[1])
TRACED GRAPH
===== __compiled_fn_9_08a72ca3_c6ee_45c6_a198_0e8c99e7092d =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[4096, 4096][4096, 1]cuda:0"):
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:147 in foo3, code: y = x + 1
y: "f32[4096, 4096][4096, 1]cuda:0" = l_x_ + 1; l_x_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:148 in foo3, code: z = torch.nn.functional.relu(y)
z: "f32[4096, 4096][4096, 1]cuda:0" = torch.nn.functional.relu(y); y = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:149 in foo3, code: u = z * 2
u: "f32[4096, 4096][4096, 1]cuda:0" = z * 2; z = None
return (u,)
compile: 0.40412646532058716
eager: 0.02964000031352043
Notice that torch.compile appears to take a lot longer to complete compared to eager. This is because torch.compile takes extra time to compile the model on the first few executions. torch.compile re-uses compiled code whenever possible, so if we run our optimized model several more times, we should see a significant improvement compared to eager.
# turn off logging for now to prevent spam
torch._logging.set_logs(graph_code=False)

eager_times = []
for i in range(10):
    _, eager_time = timed(lambda: foo3(inp))
    eager_times.append(eager_time)
    print(f"eager time {i}: {eager_time}")
print("~" * 10)

compile_times = []
for i in range(10):
    _, compile_time = timed(lambda: opt_foo3(inp))
    compile_times.append(compile_time)
    print(f"compile time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager time 0: 0.00088900001719594
eager time 1: 0.0008459999808110297
eager time 2: 0.0008459999808110297
eager time 3: 0.0008479999960400164
eager time 4: 0.000846999988425523
eager time 5: 0.0008420000085607171
eager time 6: 0.0008420000085607171
eager time 7: 0.0008509375038556755
eager time 8: 0.0008399999933317304
eager time 9: 0.0008440000237897038
~~~~~~~~~~
compile time 0: 0.0005019999807700515
compile time 1: 0.0003699999942909926
compile time 2: 0.00036100001307204366
compile time 3: 0.0003539999888744205
compile time 4: 0.00035700001171790063
compile time 5: 0.0003530000103637576
compile time 6: 0.0003530000103637576
compile time 7: 0.0003499999875202775
compile time 8: 0.0003539999888744205
compile time 9: 0.0003530000103637576
~~~~~~~~~~
(eval) eager median: 0.0008459999808110297, compile median: 0.0003539999888744205, speedup: 2.389830529376495x
~~~~~~~~~~
And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU reads/writes, so the observed speedup may vary with factors such as model architecture and batch size. For example, if a model's architecture is simple and the amount of data is large, then the bottleneck would be GPU compute and the observed speedup may be less significant.

To see speedups on a real model, check out our end-to-end torch.compile tutorial.
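As an aside, torch.compile also accepts a mode argument that trades compile time for potentially faster generated code. Below is a minimal sketch of the documented options; whether they help depends heavily on the model and hardware, and they are not used elsewhere in this tutorial.

# A sketch of torch.compile's `mode` argument (not used in the rest of this tutorial).
# "reduce-overhead" uses CUDA graphs to cut per-call launch overhead, while
# "max-autotune" spends extra compile time searching for faster kernels.
opt_default = torch.compile(foo3)
opt_low_overhead = torch.compile(foo3, mode="reduce-overhead")
opt_autotuned = torch.compile(foo3, mode="max-autotune")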
Benefits over TorchScript

Why should we use torch.compile over TorchScript? Primarily, the advantage of torch.compile lies in its ability to handle arbitrary Python code with minimal changes to existing code.

Compare this to TorchScript, which has a tracing mode (torch.jit.trace) and a scripting mode (torch.jit.script). Tracing mode is susceptible to silent incorrectness, while scripting mode requires significant code changes and raises errors on unsupported Python code.

For example, TorchScript tracing silently fails on data-dependent control flow (the if x.sum() < 0: line below) because only the actual control flow path is traced. In comparison, torch.compile is able to correctly handle it.
def f1(x, y):
    if x.sum() < 0:
        return -y
    return y

# Test that `fn1` and `fn2` return the same result, given the same arguments `args`.
def test_fns(fn1, fn2, args):
    out1 = fn1(*args)
    out2 = fn2(*args)
    return torch.allclose(out1, out2)

inp1 = torch.randn(5, 5)
inp2 = torch.randn(5, 5)

traced_f1 = torch.jit.trace(f1, (inp1, inp2))
print("traced 1, 1:", test_fns(f1, traced_f1, (inp1, inp2)))
print("traced 1, 2:", test_fns(f1, traced_f1, (-inp1, inp2)))

compile_f1 = torch.compile(f1)
print("compile 1, 1:", test_fns(f1, compile_f1, (inp1, inp2)))
print("compile 1, 2:", test_fns(f1, compile_f1, (-inp1, inp2)))
print("~" * 10)
/var/lib/workspace/intermediate_source/torch_compile_tutorial.py:239: TracerWarning:
Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
traced 1, 1: True
traced 1, 2: False
compile 1, 1: True
compile 1, 2: True
~~~~~~~~~~
TorchScript scripting can handle data-dependent control flow, but it can require major code changes and raises errors when unsupported Python is used.

In the example below, we forget TorchScript type annotations and we receive a TorchScript error because the input type for argument y, an int, does not match the default argument type, torch.Tensor. In comparison, torch.compile works without requiring any type annotations.
import traceback as tb

torch._logging.set_logs(graph_code=True)

def f2(x, y):
    return x + y

inp1 = torch.randn(5, 5)
inp2 = 3

script_f2 = torch.jit.script(f2)
try:
    script_f2(inp1, inp2)
except:
    tb.print_exc()

compile_f2 = torch.compile(f2)
print("compile 2:", test_fns(f2, compile_f2, (inp1, inp2)))
print("~" * 10)
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 288, in <module>
script_f2(inp1, inp2)
RuntimeError: f2() Expected a value of type 'Tensor (inferred)' for argument 'y' but instead found type 'int'.
Inferred 'y' to be of type 'Tensor' because it was not annotated with an explicit type.
Position: 1
Value: 3
Declaration: f2(Tensor x, Tensor y) -> Tensor
Cast error details: Unable to cast 3 to Tensor
TRACED GRAPH
===== __compiled_fn_18_60f88fab_6a3d_4dcc_a2ea_16a1899bfb1f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[5, 5][5, 1]cpu"):
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:280 in f2, code: return x + y
add: "f32[5, 5][5, 1]cpu" = l_x_ + 3; l_x_ = None
return (add,)
compile 2: True
~~~~~~~~~~
Graph Breaks

The graph break is one of the most fundamental concepts within torch.compile. It allows torch.compile to handle arbitrary Python code by interrupting compilation, running the unsupported code, then resuming compilation.

The term "graph break" comes from the fact that torch.compile attempts to capture and optimize the PyTorch operation graph. When unsupported Python code is encountered, this graph must be "broken". Graph breaks result in lost optimization opportunities, which may still be undesirable, but this is better than silent incorrectness or a hard crash.

Let's look at a data-dependent control flow example to better see how graph breaks work.
def bar(a, b):
    x = a / (torch.abs(a) + 1)
    if b.sum() < 0:
        b = b * -1
    return x * b

opt_bar = torch.compile(bar)
inp1 = torch.ones(10)
inp2 = torch.ones(10)
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_20_d5309909_d209_4382_9b82_0ba74ced4ca8 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
return (lt, x)
TRACED GRAPH
===== __compiled_fn_24_24e667b5_a8e5_442d_b94a_a878f1114d23 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_x_ = L_x_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_26_d1830df0_39a5_4379_96f3_af6c112110cd =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
l_b_ = L_b_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
The first time we run bar, we see that torch.compile traced 2 graphs corresponding to the following code (noting that b.sum() < 0 is False):

1. x = a / (torch.abs(a) + 1); b.sum()
2. return x * b

The second time we run bar, we take the other branch of the if statement and we get 1 traced graph corresponding to the code b = b * -1; return x * b. We do not see a graph of x = a / (torch.abs(a) + 1) output the second time, since torch.compile cached this graph from the first run and re-used it.
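If you want to confirm when cached code is re-used versus recompiled, one option (a sketch, not part of the original example) is to enable recompile logging:

# A sketch: log the reason for any recompilation (e.g., a guard failure).
torch._logging.set_logs(recompiles=True)
opt_bar(inp1, inp2)   # same inputs as before, so the cached graphs are re-used
opt_bar(inp1, -inp2)
torch._logging.set_logs(recompiles=False)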
Let's investigate by example how TorchDynamo would step through bar. If b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 3 (the graph containing b = b * -1; return x * b). On the other hand, if not b.sum() < 0, then TorchDynamo would run graph 1, let Python determine the result of the conditional, then run graph 2 (the graph containing return x * b).
We can see all graph breaks by using torch._logging.set_logs(graph_breaks=True).
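For example, here is a sketch of how one might re-run bar with graph-break logging enabled on a fresh compile cache (this code block is an assumption, not taken from the original tutorial):

# A sketch: enable graph-break logging and re-trace bar from scratch.
torch._logging.set_logs(graph_code=True, graph_breaks=True)
torch._dynamo.reset()  # clear the compile cache so bar is traced again
opt_bar(inp1, inp2)
opt_bar(inp1, -inp2)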
TRACED GRAPH
===== __compiled_fn_28_e75c1c8c_4795_4a16_8d6f_90d489a9e78e =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:312 in bar, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:313 in bar, code: if b.sum() < 0:
sum_1: "f32[][]cpu" = l_b_.sum(); l_b_ = None
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
return (lt, x)
TRACED GRAPH
===== __compiled_fn_32_e26b0760_f8cc_414d_a852_6092ac007ca7 =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_x_ = L_x_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul: "f32[10][1]cpu" = l_x_ * l_b_; l_x_ = l_b_ = None
return (mul,)
TRACED GRAPH
===== __compiled_fn_34_2b406644_b833_40a0_96ec_c1f387d13c7f =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_b_: "f32[10][1]cpu", L_x_: "f32[10][1]cpu"):
l_b_ = L_b_
l_x_ = L_x_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:314 in torch_dynamo_resume_in_bar_at_313, code: b = b * -1
b: "f32[10][1]cpu" = l_b_ * -1; l_b_ = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:315 in torch_dynamo_resume_in_bar_at_313, code: return x * b
mul_1: "f32[10][1]cpu" = l_x_ * b; l_x_ = b = None
return (mul_1,)
tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
0.5000])
In order to maximize speedup, graph breaks should be limited. We can force TorchDynamo to raise an error upon the first graph break encountered by using fullgraph=True:
# Reset to clear the torch.compile cache
torch._dynamo.reset()

opt_bar_fullgraph = torch.compile(bar, fullgraph=True)
try:
    opt_bar_fullgraph(torch.randn(10), torch.randn(10))
except:
    tb.print_exc()
Traceback (most recent call last):
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 360, in <module>
opt_bar_fullgraph(torch.randn(10), torch.randn(10))
File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 841, in compile_wrapper
raise e.with_traceback(None) from e.__cause__ # User compiler error
torch._dynamo.exc.Unsupported: Data-dependent branching
Explanation: Detected data-dependent branching (e.g. `if my_tensor.sum() > 0:`). Dynamo does not support tracing dynamic control flow.
Hint: This graph break is fundamental - it is unlikely that Dynamo will ever be able to trace through your code. Consider finding a workaround.
Hint: Use `torch.cond` to express dynamic control flow.
Developer debug context: attempted to jump with TensorVariable()
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0170.html
from user code:
File "/var/lib/workspace/intermediate_source/torch_compile_tutorial.py", line 313, in bar
if b.sum() < 0:
Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
In our example above, we can work around this graph break by replacing the if statement with a torch.cond:
from functorch.experimental.control_flow import cond

@torch.compile(fullgraph=True)
def bar_fixed(a, b):
    x = a / (torch.abs(a) + 1)

    def true_branch(y):
        return y * -1

    def false_branch(y):
        # NOTE: torch.cond doesn't allow aliased outputs
        return y.clone()

    x = cond(b.sum() < 0, true_branch, false_branch, (b,))
    return x * b

bar_fixed(inp1, inp2)
bar_fixed(inp1, -inp2)
TRACED GRAPH
===== __compiled_fn_37_6c5f108a_d951_495b_a538_024359c8fc5a =====
/usr/local/lib/python3.10/dist-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
def forward(self, L_a_: "f32[10][1]cpu", L_b_: "f32[10][1]cpu"):
l_a_ = L_a_
l_b_ = L_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:373 in bar_fixed, code: x = a / (torch.abs(a) + 1)
abs_1: "f32[10][1]cpu" = torch.abs(l_a_)
add: "f32[10][1]cpu" = abs_1 + 1; abs_1 = None
x: "f32[10][1]cpu" = l_a_ / add; l_a_ = add = x = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:382 in bar_fixed, code: x = cond(b.sum() < 0, true_branch, false_branch, (b,))
sum_1: "f32[][]cpu" = l_b_.sum()
lt: "b8[][]cpu" = sum_1 < 0; sum_1 = None
# File: /usr/local/lib/python3.10/dist-packages/torch/_higher_order_ops/cond.py:186 in cond, code: return cond_op(pred, true_fn, false_fn, operands)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(lt, cond_true_0, cond_false_0, (l_b_,)); lt = cond_true_0 = cond_false_0 = None
x_1: "f32[10][1]cpu" = cond[0]; cond = None
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:383 in bar_fixed, code: return x * b
mul: "f32[10][1]cpu" = x_1 * l_b_; x_1 = l_b_ = None
return (mul,)
class cond_true_0(torch.nn.Module):
def forward(self, l_b_: "f32[10][1]cpu"):
l_b__1 = l_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:376 in true_branch, code: return y * -1
mul: "f32[10][1]cpu" = l_b__1 * -1; l_b__1 = None
return (mul,)
class cond_false_0(torch.nn.Module):
def forward(self, l_b_: "f32[10][1]cpu"):
l_b__1 = l_b_
# File: /var/lib/workspace/intermediate_source/torch_compile_tutorial.py:380 in false_branch, code: return y.clone()
clone: "f32[10][1]cpu" = l_b__1.clone(); l_b__1 = None
return (clone,)
tensor([-1., -1., -1., -1., -1., -1., -1., -1., -1., -1.])
In order to serialize graphs or to run graphs in different (i.e., Python-less) environments, consider using torch.export instead (available from PyTorch 2.1 onwards). One important restriction is that torch.export does not support graph breaks. Please check the torch.export tutorial for more details on torch.export.
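For reference, a minimal torch.export sketch (assuming PyTorch 2.1+; the module below is illustrative and not part of the original example):

import torch

class SinCos(torch.nn.Module):
    def forward(self, x, y):
        return torch.sin(x) + torch.cos(y)

# torch.export captures a single, full graph ahead of time; graph breaks are not allowed.
exported = torch.export.export(SinCos(), (torch.randn(3, 3), torch.randn(3, 3)))
print(exported)  # prints the exported graph and its input/output signature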
Check out our section on graph breaks in the torch.compile programming model for tips on how to work around graph breaks.
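One common workaround worth knowing (a sketch; not covered in this tutorial) is to explicitly exclude a problematic function from compilation with torch.compiler.disable, turning it into a deliberate graph break:

# A sketch: `debug_print` is a hypothetical helper that we deliberately run eagerly.
@torch.compiler.disable
def debug_print(x):
    print("sum is", x.sum().item())
    return x

@torch.compile
def fn_with_break(x):
    x = torch.sin(x)
    x = debug_print(x)  # graph break here; compilation resumes afterwards
    return torch.cos(x)

fn_with_break(torch.randn(8))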
Troubleshooting

Is torch.compile failing to speed up your model? Is compile time unreasonably long? Is your code recompiling excessively? Are you having difficulties dealing with graph breaks? Are you looking for tips on how to best use torch.compile? Or maybe you simply want to learn more about the inner workings of torch.compile?

Check out the torch.compile programming model.
Conclusion

In this tutorial, we introduced torch.compile by covering basic usage, demonstrating speedups over eager mode, comparing to TorchScript, and briefly describing graph breaks.

For an end-to-end example on a real model, check out our end-to-end torch.compile tutorial. To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

We hope that you will give torch.compile a try!
Total running time of the script: (0 minutes 16.527 seconds)