
torch.compile End-to-End Tutorial

Author: William Wen

torch.compile is the new way to speed up your PyTorch code! torch.compile makes PyTorch code run faster by JIT-compiling PyTorch code into optimized kernels, while requiring minimal code changes.

This tutorial covers an end-to-end example of training and evaluating a real model with torch.compile. For a gentle introduction to torch.compile, please check out the introduction to torch.compile tutorial.
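If you have not used torch.compile before, the basic usage pattern is simply to wrap a function (or module) and call the result as you normally would. The short sketch below uses a toy function purely for illustration; it is not part of the model we benchmark in this tutorial.

import torch

# A toy function, used only to illustrate the basic API.
def toy_fn(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# torch.compile returns an optimized callable with the same signature.
compiled_toy_fn = torch.compile(toy_fn)
print(compiled_toy_fn(torch.randn(10), torch.randn(10)))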

Required pip Dependencies

  • torch >= 2.0

  • torchvision

What you will learn
  • How to apply torch.compile to a real model

  • torch.compile speedups on a real model

  • torch.compile’s first few iterations are expected to be slower due to compilation overhead

# NOTE: a modern NVIDIA GPU (H100, A100, or V100) is recommended for this tutorial in
# order to reproduce the speedup numbers shown below and documented elsewhere.

import torch
import warnings

gpu_ok = False
if torch.cuda.is_available():
    device_cap = torch.cuda.get_device_capability()
    if device_cap in ((7, 0), (8, 0), (9, 0)):
        gpu_ok = True

if not gpu_ok:
    warnings.warn(
        "GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower "
        "than expected."
    )
/var/lib/workspace/intermediate_source/torch_compile_full_example.py:51: UserWarning:

GPU is not NVIDIA V100, A100, or H100. Speedup numbers may be lower than expected.

Let’s demonstrate how using torch.compile can speed up a real model. We will compare standard eager mode and torch.compile by evaluating and training a torchvision model on random data.

Before we start, we need to define some utility functions.

# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn()
    end.record()
    torch.cuda.synchronize()
    return result, start.elapsed_time(end) / 1000


# Generates random input and targets data for the model, where `b` is
# batch size.
def generate_data(b):
    return (
        torch.randn(b, 3, 128, 128).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )


N_ITERS = 10

from torchvision.models import densenet121


def init_model():
    return densenet121().cuda()

First, let’s compare inference.

Note that when we compile the model, we pass the additional mode argument, which we will discuss below.

model = init_model()

# Note that we generally recommend directly compiling a torch.nn.Module by calling
# its .compile() method.
model_opt = init_model()
model_opt.compile(mode="reduce-overhead")

inp = generate_data(16)[0]
with torch.no_grad():
    print("eager:", timed(lambda: model(inp))[1])
    print("compile:", timed(lambda: model_opt(inp))[1])
eager: 0.3604090576171875
/usr/local/lib/python3.10/dist-packages/torch/backends/cuda/__init__.py:131: UserWarning:

Please use the new API settings to control TF32 behavior, such as torch.backends.cudnn.conv.fp32_precision = 'tf32' or torch.backends.cuda.matmul.fp32_precision = 'ieee'. Old settings, e.g, torch.backends.cuda.matmul.allow_tf32 = True, torch.backends.cudnn.allow_tf32 = True, allowTF32CuDNN() and allowTF32CuBLAS() will be deprecated after Pytorch 2.9. Please see https://pytorch.org/docs/main/notes/cuda.html#tensorfloat-32-tf32-on-ampere-and-later-devices (Triggered internally at /pytorch/aten/src/ATen/Context.cpp:80.)

/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:312: UserWarning:

TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.

compile: 51.42688671875
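
As an aside, the second warning above notes that TF32 tensor cores are available but not enabled. If you want to follow its suggestion, you can opt in with a single call before benchmarking; note that the numbers shown in this tutorial were measured without this setting.

# Optional: allow float32 matmuls to use TF32 tensor cores, as the Inductor
# warning above suggests. This trades a small amount of precision for speed
# on Ampere-and-later GPUs.
torch.set_float32_matmul_precision("high")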

Notice that torch.compile takes a lot longer to complete compared to eager. This is because torch.compile compiles the model into optimized kernels as it executes. In our example, the structure of the model doesn’t change, and so recompilation is not needed. So if we run our optimized model several more times, we should see a significant improvement compared to eager.

eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, eager_time = timed(lambda: model(inp))
    eager_times.append(eager_time)
    print(f"eager eval time {i}: {eager_time}")

print("~" * 10)

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)[0]
    with torch.no_grad():
        _, compile_time = timed(lambda: model_opt(inp))
    compile_times.append(compile_time)
    print(f"compile eval time {i}: {compile_time}")
print("~" * 10)

import numpy as np

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(eval) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager eval time 0: 0.01820876884460449
eager eval time 1: 0.016675840377807616
eager eval time 2: 0.016416767120361327
eager eval time 3: 0.01638400077819824
eager eval time 4: 0.016457696914672852
eager eval time 5: 0.016348159790039063
eager eval time 6: 0.016328704833984374
eager eval time 7: 0.016314367294311523
eager eval time 8: 0.01641472053527832
eager eval time 9: 0.01641164779663086
~~~~~~~~~~
compile eval time 0: 0.061233150482177735
compile eval time 1: 0.007819263935089112
compile eval time 2: 0.008339455604553223
compile eval time 3: 0.007483391761779785
compile eval time 4: 0.007483359813690186
compile eval time 5: 0.007465983867645264
compile eval time 6: 0.0074670081138610836
compile eval time 7: 0.0074670081138610836
compile eval time 8: 0.007468031883239746
compile eval time 9: 0.0074700798988342285
~~~~~~~~~~
(eval) eager median: 0.016413184165954588, compile median: 0.007476719856262207, speedup: 2.1952386182033488x
~~~~~~~~~~

And indeed, we can see that running our model with torch.compile results in a significant speedup. The speedup mainly comes from reducing Python overhead and GPU read/writes, so the observed speedup may vary depending on factors such as model architecture and batch size. For example, if a model’s architecture is simple and the amount of data is large, then the bottleneck will be GPU compute and the observed speedup may be less significant.
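
If you want to see how this plays out on your own hardware, a minimal sketch along these lines repeats the evaluation benchmark at a few batch sizes (the batch sizes below are arbitrary choices for illustration):

# Sketch: see how the eager vs. compiled gap changes with batch size.
for b in (4, 16, 64):
    inp = generate_data(b)[0]
    with torch.no_grad():
        model_opt(inp)  # untimed warm-up; a new shape may trigger recompilation
        _, eager_t = timed(lambda: model(inp))
        _, compile_t = timed(lambda: model_opt(inp))
    print(f"batch size {b}: eager {eager_t:.4f}s, compile {compile_t:.4f}s")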

You may also see different speedup results depending on the chosen mode argument. The "reduce-overhead" mode uses CUDA graphs to further reduce the overhead of Python. For your own models, you may need to experiment with different modes to maximize speedup. You can read more about modes here.
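
For example, a minimal sketch for comparing modes on this model might look like the following (it reuses the init_model and timed helpers defined above, and each mode string is one that torch.compile accepts):

# Sketch: compare torch.compile modes on the same model. Each mode trades
# extra compile time for potentially better runtime performance.
for mode in ("default", "reduce-overhead", "max-autotune"):
    m = init_model()
    m.compile(mode=mode)
    inp = generate_data(16)[0]
    with torch.no_grad():
        m(inp)  # untimed warm-up/compile run
        _, t = timed(lambda: m(inp))
    print(f"mode {mode}: {t:.4f}s")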

You might also notice that the second time we run our model with torch.compile is significantly slower than the other runs, although it is still much faster than the first run. This is because the "reduce-overhead" mode runs a few warm-up iterations for CUDA graphs.
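
If you want to keep this warm-up cost out of your own measurements, one option is to run a few untimed iterations before timing; a minimal sketch (the warm-up count of 3 is an arbitrary choice):

# Sketch: run a few untimed iterations so that compilation and CUDA graph
# capture finish before the timed measurements begin.
warmup_inp = generate_data(16)[0]
with torch.no_grad():
    for _ in range(3):
        model_opt(warmup_inp)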

Now, let’s compare training.

model = init_model()
opt = torch.optim.Adam(model.parameters())


def train(mod, data):
    opt.zero_grad(set_to_none=True)
    pred = mod(data[0])
    loss = torch.nn.CrossEntropyLoss()(pred, data[1])
    loss.backward()
    opt.step()


eager_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, eager_time = timed(lambda: train(model, inp))
    eager_times.append(eager_time)
    print(f"eager train time {i}: {eager_time}")
print("~" * 10)

model = init_model()
opt = torch.optim.Adam(model.parameters())

# Note that because we are compiling a regular Python function, we do not
# call any .compile() method.
train_opt = torch.compile(train, mode="reduce-overhead")

compile_times = []
for i in range(N_ITERS):
    inp = generate_data(16)
    _, compile_time = timed(lambda: train_opt(model, inp))
    compile_times.append(compile_time)
    print(f"compile train time {i}: {compile_time}")
print("~" * 10)

eager_med = np.median(eager_times)
compile_med = np.median(compile_times)
speedup = eager_med / compile_med
assert speedup > 1
print(
    f"(train) eager median: {eager_med}, compile median: {compile_med}, speedup: {speedup}x"
)
print("~" * 10)
eager train time 0: 0.2882539367675781
eager train time 1: 0.05161676788330078
eager train time 2: 0.049276927947998046
eager train time 3: 0.05065420913696289
eager train time 4: 0.8006707153320313
eager train time 5: 0.05070438385009766
eager train time 6: 0.05034195327758789
eager train time 7: 0.05022825622558594
eager train time 8: 0.050223102569580076
eager train time 9: 0.05043302536010742
~~~~~~~~~~
compile train time 0: 151.00690625
compile train time 1: 2.915029052734375
compile train time 2: 0.02395030403137207
compile train time 3: 0.021402624130249022
compile train time 4: 0.020746240615844725
compile train time 5: 0.02069811248779297
compile train time 6: 0.020706304550170897
compile train time 7: 0.020715520858764647
compile train time 8: 0.02070425605773926
compile train time 9: 0.020745216369628908
~~~~~~~~~~
(train) eager median: 0.05054361724853516, compile median: 0.020745728492736815, speedup: 2.436338510177203x
~~~~~~~~~~

Again, we can see that torch.compile takes longer in the first iteration, as it must compile the model, but in subsequent iterations, we see significant speedups compared to eager.

We remark that the speedup numbers presented in this tutorial are for demonstration purposes only. Official speedup values can be seen at the TorchInductor performance dashboard.

Conclusion

In this tutorial, we applied torch.compile to training and inference on a real model, demonstrating speedups.

Importantly, we note that the first few iterations of a compiled model are slower than eager mode due to compilation overhead, but subsequent iterations are expected to have speedups.

For a gentle introduction to torch.compile, please check out the introduction to torch.compile tutorial.

To troubleshoot issues and to gain a deeper understanding of how to apply torch.compile to your code, check out the torch.compile programming model.

We hope that you will give torch.compile a try!

Total running time of the script: (3 minutes 29.786 seconds)