AOTInductor Tutorial for Python runtime (Beta)
Created On: Aug 23, 2024 | Last Updated: Jan 24, 2025 | Last Verified: Nov 05, 2024
Author: Ankith Gunapal, Bin Bao, Angela Yi
Warning
torch._inductor.aoti_compile_and_package and torch._inductor.aoti_load_package are in Beta status and are subject to backwards compatibility breaking changes. This tutorial provides an example of how to use these APIs for model deployment using Python runtime.
A previous tutorial showed how AOTInductor can be used for ahead-of-time compilation of PyTorch exported models by creating an artifact that can be run in a non-Python environment. This tutorial walks through an end-to-end example of using AOTInductor with the Python runtime.
Prerequisites
PyTorch 2.6 or later
Basic understanding of torch.export and AOTInductor
Complete the AOTInductor: Ahead-Of-Time Compilation for Torch.Export-ed Models tutorial
What you will learn
How to use AOTInductor for Python runtime.
How to use torch._inductor.aoti_compile_and_package() along with torch.export.export() to generate a compiled artifact.
How to load and run the artifact in a Python runtime using torch._inductor.aoti_load_package().
When to use AOTInductor with a Python runtime.
Model Compilation
We will use the TorchVision pretrained ResNet18 model as an example.
The first step is to export the model to a graph representation using torch.export.export(). To learn more about using this function, you can check out the docs or the tutorial.
Once we have exported the PyTorch model and obtained an ExportedProgram, we can call torch._inductor.aoti_compile_and_package() to have AOTInductor compile the program for a specified device and save the generated contents into a “.pt2” artifact.
Note
This API supports the same available options that torch.compile() has, such as mode and max_autotune (for those who want to enable CUDA graphs and leverage Triton based matrix multiplications and convolutions).
import os
import torch
import torch._inductor
from torchvision.models import ResNet18_Weights, resnet18

# Load the pretrained model and switch it to evaluation mode
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.eval()

with torch.inference_mode():
    inductor_configs = {}

    # Compile for CUDA when available and enable max_autotune there
    if torch.cuda.is_available():
        device = "cuda"
        inductor_configs["max_autotune"] = True
    else:
        device = "cpu"

    model = model.to(device=device)
    example_inputs = (torch.randn(2, 3, 224, 224, device=device),)

    # Export the model to an ExportedProgram
    exported_program = torch.export.export(
        model,
        example_inputs,
    )
    # Compile the exported program and save it as a .pt2 artifact
    path = torch._inductor.aoti_compile_and_package(
        exported_program,
        package_path=os.path.join(os.getcwd(), "resnet18.pt2"),
        inductor_configs=inductor_configs,
    )
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /var/lib/ci-user/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
0%| | 0.00/44.7M [00:00<?, ?B/s]
92%|#########1| 40.9M/44.7M [00:00<00:00, 429MB/s]
100%|##########| 44.7M/44.7M [00:00<00:00, 429MB/s]
/usr/local/lib/python3.10/dist-packages/torch/_inductor/compile_fx.py:236: UserWarning:
TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
convolution 0.0973 ms 100.0%
triton_convolution2d_4 0.1505 ms 64.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_0 0.1536 ms 63.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_3 0.1956 ms 49.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_5 0.2386 ms 40.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_2 0.3041 ms 32.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
triton_convolution2d_1 0.4076 ms 23.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=7, KERNEL_W=7, PADDING_H=3, PADDING_W=3, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4203 seconds and 0.0019 seconds precompiling for 7 choices
AUTOTUNE convolution(2x64x56x56, 64x64x3x3)
triton_convolution2d_11 0.0369 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_6 0.0379 ms 97.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_10 0.0410 ms 90.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_9 0.0420 ms 87.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_7 0.0604 ms 61.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_12 0.0655 ms 56.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
convolution 0.0799 ms 46.2%
triton_convolution2d_8 0.1208 ms 30.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.1812 seconds and 0.0004 seconds precompiling for 8 choices
AUTOTUNE convolution(2x64x56x56, 128x64x3x3)
triton_convolution2d_38 0.0307 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_39 0.0358 ms 85.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_34 0.0502 ms 61.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_35 0.0584 ms 52.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_40 0.0625 ms 49.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_37 0.0645 ms 47.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
convolution 0.0870 ms 35.3%
triton_convolution2d_36 0.1126 ms 27.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.1761 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x128x28x28, 128x128x3x3)
triton_convolution2d_45 0.0492 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_46 0.0748 ms 65.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
convolution 0.0758 ms 64.9%
triton_convolution2d_41 0.0901 ms 54.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_42 0.1106 ms 44.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_44 0.1147 ms 42.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_47 0.1198 ms 41.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_43 0.2273 ms 21.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2084 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x64x56x56, 128x64x1x1)
triton_convolution2d_52 0.0082 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_53 0.0102 ms 80.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_48 0.0113 ms 72.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_51 0.0133 ms 61.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_54 0.0133 ms 61.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_49 0.0143 ms 57.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_50 0.0184 ms 44.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
convolution 0.0307 ms 26.7%
SingleProcess AUTOTUNE benchmarking takes 0.1477 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x128x28x28, 256x128x3x3)
convolution 0.0461 ms 100.0%
triton_convolution2d_73 0.0461 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_69 0.1065 ms 43.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_70 0.1085 ms 42.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_74 0.1085 ms 42.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_75 0.1157 ms 39.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_72 0.1208 ms 38.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_71 0.1782 ms 25.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2115 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x256x14x14, 256x256x3x3)
convolution 0.0850 ms 100.0%
triton_convolution2d_80 0.0922 ms 92.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_81 0.2038 ms 41.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_76 0.2058 ms 41.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_77 0.2130 ms 39.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_78 0.2130 ms 39.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
triton_convolution2d_79 0.2232 ms 38.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_82 0.2335 ms 36.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2763 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x128x28x28, 256x128x1x1)
triton_convolution2d_87 0.0102 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_86 0.0184 ms 55.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_88 0.0184 ms 55.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_89 0.0184 ms 55.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_83 0.0225 ms 45.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_84 0.0225 ms 45.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
convolution 0.0246 ms 41.7%
triton_convolution2d_85 0.0256 ms 40.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.1467 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x256x14x14, 512x256x3x3)
convolution 0.0860 ms 100.0%
triton_convolution2d_108 0.0901 ms 95.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_104 0.2079 ms 41.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_106 0.2099 ms 40.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=1, num_warps=8
triton_convolution2d_109 0.2150 ms 40.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_110 0.2335 ms 36.8% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_107 0.2365 ms 36.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_105 0.2724 ms 31.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=2, STRIDE_W=2, UNROLL=False, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.2768 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x512x7x7, 512x512x3x3)
convolution 0.0829 ms 100.0%
triton_convolution2d_115 0.1823 ms 45.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_113 0.2161 ms 38.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
triton_convolution2d_117 0.2714 ms 30.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_112 0.2857 ms 29.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_111 0.4106 ms 20.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_116 0.4229 ms 19.6% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_114 0.4434 ms 18.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.2957 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE convolution(2x256x14x14, 512x256x1x1)
triton_convolution2d_122 0.0154 ms 100.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_120 0.0256 ms 60.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=512, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=1, num_warps=8
triton_convolution2d_118 0.0287 ms 53.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_119 0.0287 ms 53.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_124 0.0307 ms 50.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_121 0.0317 ms 48.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_123 0.0317 ms 48.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=2, STRIDE_W=2, UNROLL=True, num_stages=2, num_warps=8
convolution 0.0358 ms 42.9%
SingleProcess AUTOTUNE benchmarking takes 0.1557 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE addmm(2x1000, 2x512, 512x1000)
addmm 0.0154 ms 100.0%
triton_mm_142 0.0215 ms 71.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_143 0.0225 ms 68.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=2
triton_mm_140 0.0236 ms 65.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=2
triton_mm_152 0.0297 ms 51.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_141 0.0307 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=5, num_warps=4
triton_mm_146 0.0307 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_153 0.0307 ms 50.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=4
triton_mm_139 0.0338 ms 45.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=16, BLOCK_N=32, EVEN_K=True, GROUP_M=8, num_stages=1, num_warps=2
triton_mm_145 0.0369 ms 41.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=16, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.4915 seconds and 0.0002 seconds precompiling for 18 choices
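The autotune log above is long, so it can help to reduce it to the fastest candidate per operator. The following is a minimal sketch that assumes the line layout seen in this particular run (an `AUTOTUNE op(shapes)` header followed by `name time ms percent%` rows); the log format is not a stable interface and may differ between PyTorch versions. The helper name `fastest_choices` is hypothetical and not part of PyTorch.

```python
import re

# Patterns assumed from the log excerpt shown in this tutorial run.
HEADER = re.compile(r"^AUTOTUNE\s+(\w+)\((.*)\)\s*$")
CHOICE = re.compile(r"^\s*(\S+)\s+([\d.]+)\s+ms\s+([\d.]+)%")


def fastest_choices(log_text):
    """Map each autotuned operator to its (kernel_name, time_ms) with the lowest time."""
    results = {}
    current = None
    for line in log_text.splitlines():
        header = HEADER.match(line)
        if header:
            current = f"{header.group(1)}({header.group(2)})"
            continue
        choice = CHOICE.match(line)
        if choice and current is not None:
            name, ms = choice.group(1), float(choice.group(2))
            best = results.get(current)
            # Keep the first entry on ties, matching the order in the log.
            if best is None or ms < best[1]:
                results[current] = (name, ms)
    return results


if __name__ == "__main__":
    sample = """AUTOTUNE convolution(2x3x224x224, 64x3x7x7)
convolution 0.0973 ms 100.0%
triton_convolution2d_4 0.1505 ms 64.6% ALLOW_TF32=True, BLOCK_K=16
SingleProcess AUTOTUNE benchmarking takes 0.4203 seconds and 0.0019 seconds precompiling for 7 choices
"""
    for op, (kernel, ms) in fastest_choices(sample).items():
        print(f"{op}: {kernel} ({ms} ms)")
```

In the rows above, the percentage column appears to be the best time divided by the candidate's time, so the 100.0% row is the candidate that autotuning measured as fastest.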
The result of aoti_compile_and_package() is an artifact “resnet18.pt2” which can be loaded and executed in Python and C++.
The artifact itself contains AOTInductor generated code, such as a generated C++ runner file, a shared library compiled from the C++ file, and CUDA binary files (cubin files) if optimizing for CUDA.
Structure-wise, the artifact is a structured .zip file, with the following specification:
We can use the following command to inspect the artifact contents:
$ unzip -l resnet18.pt2
Archive: resnet18.pt2
Length Date Time Name
--------- ---------- ----- ----
1 01-08-2025 16:40 version
3 01-08-2025 16:40 archive_format
10088 01-08-2025 16:40 data/aotinductor/model/cagzt6akdaczvxwtbvqe34otfe5jlorktbqlojbzqjqvbfsjlge4.cubin
17160 01-08-2025 16:40 data/aotinductor/model/c6oytfjmt5w4c7onvtm6fray7clirxt7q5xjbwx3hdydclmwoujz.cubin
16616 01-08-2025 16:40 data/aotinductor/model/c7ydp7nocyz323hij4tmlf2kcedmwlyg6r57gaqzcsy3huneamu6.cubin
17776 01-08-2025 16:40 data/aotinductor/model/cyqdf46ordevqhiddvpdpp3uzwatfbzdpl3auj2nx23uxvplnne2.cubin
10856 01-08-2025 16:40 data/aotinductor/model/cpzfebfgrusqslui7fxsuoo4tvwulmrxirc5tmrpa4mvrbdno7kn.cubin
14608 01-08-2025 16:40 data/aotinductor/model/c5ukeoz5wmaszd7vczdz2qhtt6n7tdbl3b6wuy4rb2se24fjwfoy.cubin
11376 01-08-2025 16:40 data/aotinductor/model/csu3nstcp56tsjfycygaqsewpu64l5s6zavvz7537cm4s4cv2k3r.cubin
10984 01-08-2025 16:40 data/aotinductor/model/cp76lez4glmgq7gedf2u25zvvv6rksv5lav4q22dibd2zicbgwj3.cubin
14736 01-08-2025 16:40 data/aotinductor/model/c2bb5p6tnwz4elgujqelsrp3unvkgsyiv7xqxmpvuxcm4jfl7pc2.cubin
11376 01-08-2025 16:40 data/aotinductor/model/c6eopmb2b4ngodwsayae4r5q6ni3jlfogfbdk3ypg56tgpzhubfy.cubin
11624 01-08-2025 16:40 data/aotinductor/model/chmwe6lvoekzfowdbiizitm3haiiuad5kdm6sd2m6mv6dkn2zk32.cubin
15632 01-08-2025 16:40 data/aotinductor/model/c3jop5g344hj3ztsu4qm6ibxyaaerlhkzh2e6emak23rxfje6jam.cubin
25472 01-08-2025 16:40 data/aotinductor/model/chaiixybeiuuitm2nmqnxzijzwgnn2n7uuss4qmsupgblfh3h5hk.cubin
139389 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.cpp
27 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t_metadata.json
47195424 01-08-2025 16:40 data/aotinductor/model/cvk6qzuybruhwxtfblzxiov3rlrziv5fkqc4mdhbmantfu3lmd6t.so
--------- -------
47523148 18 files
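Because the artifact is a .zip file, the same inspection can be done from Python with the standard library. This is a minimal sketch; the helper names `list_package_contents` and `summarize_by_extension` are hypothetical and not part of PyTorch, and the path `resnet18.pt2` is assumed to be the file produced by the compilation step.

```python
import os
import zipfile


def list_package_contents(package_path):
    """Return (name, uncompressed_size_in_bytes) for every entry in the archive."""
    with zipfile.ZipFile(package_path) as archive:
        return [(info.filename, info.file_size) for info in archive.infolist()]


def summarize_by_extension(entries):
    """Sum the uncompressed sizes per file extension, e.g. '.so' or '.cubin'."""
    totals = {}
    for name, size in entries:
        ext = os.path.splitext(name)[1] or "(none)"
        totals[ext] = totals.get(ext, 0) + size
    return totals


if __name__ == "__main__":
    package_path = "resnet18.pt2"  # assumed to exist from the compilation step
    if os.path.exists(package_path):
        entries = list_package_contents(package_path)
        for ext, total in sorted(summarize_by_extension(entries).items()):
            print(f"{ext:>8}  {total} bytes")
```

Grouping by extension makes it easy to see that, in the listing above, the shared library (.so) accounts for most of the package size.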
Model Inference in Python
To load and run the artifact in Python, we can use torch._inductor.aoti_load_package().
import os
import torch
import torch._inductor
model_path = os.path.join(os.getcwd(), "resnet18.pt2")
compiled_model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(2, 3, 224, 224, device=device),)
with torch.inference_mode():
    output = compiled_model(example_inputs)
When to use AOTInductor with a Python Runtime
There are mainly two reasons why one would use AOTInductor with a Python Runtime:
torch._inductor.aoti_compile_and_package generates a single serialized artifact. This is useful for model versioning for deployments and tracking model performance over time.
With torch.compile() being a JIT compiler, there is a warmup cost associated with the first compilation. Your deployment needs to account for the compilation time taken for the first inference. With AOTInductor, the compilation is done ahead of time using torch.export.export and torch._inductor.aoti_compile_and_package. At deployment time, after loading the model, running inference incurs no additional compilation cost.
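One way to make use of the single-file artifact for versioning is to record a content hash alongside each deployment. This is a minimal sketch using only the standard library; `artifact_fingerprint` is a hypothetical helper, not a PyTorch API. The hash identifies a specific file on disk, and no assumption is made here that compiling the same model twice yields byte-identical packages.

```python
import hashlib
import os


def artifact_fingerprint(package_path, chunk_size=1 << 20):
    """Return the SHA-256 hex digest of a file, read in chunks to bound memory use."""
    digest = hashlib.sha256()
    with open(package_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


if __name__ == "__main__":
    package_path = "resnet18.pt2"  # assumed to exist from the compilation step
    if os.path.exists(package_path):
        print(f"{package_path}: sha256={artifact_fingerprint(package_path)}")
```

Logging this digest next to latency or accuracy metrics ties each measurement to the exact artifact that produced it.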
The section below shows the speedup achieved with AOTInductor for the first inference.
We define a utility function timed to measure the time taken for inference.
import time
def timed(fn):
    # Returns the result of running `fn()` and the time it took for `fn()` to run,
    # in milliseconds. We use CUDA events and synchronization for accurate
    # measurement on CUDA enabled devices.
    if torch.cuda.is_available():
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
    else:
        start = time.time()
    result = fn()
    if torch.cuda.is_available():
        end.record()
        torch.cuda.synchronize()
    else:
        end = time.time()
    # Measure time taken to execute the function in milliseconds
    if torch.cuda.is_available():
        duration = start.elapsed_time(end)
    else:
        duration = (end - start) * 1000
    return result, duration
Let's measure the time for first inference using AOTInductor.
torch._dynamo.reset()
model = torch._inductor.aoti_load_package(model_path)
example_inputs = (torch.randn(1, 3, 224, 224, device=device),)
with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
print(f"Time taken for first inference for AOTInductor is {time_taken:.2f} ms")
Time taken for first inference for AOTInductor is 3.17 ms
Let's measure the time for first inference using torch.compile.
torch._dynamo.reset()
model = resnet18(weights=ResNet18_Weights.DEFAULT).to(device)
model.eval()
model = torch.compile(model)
example_inputs = torch.randn(1, 3, 224, 224, device=device)
with torch.inference_mode():
    _, time_taken = timed(lambda: model(example_inputs))
print(f"Time taken for first inference for torch.compile is {time_taken:.2f} ms")
Time taken for first inference for torch.compile is 4504.18 ms
In this run, the first inference took 3.17 ms with AOTInductor versus 4504.18 ms with torch.compile, a drastic speedup in first inference time.
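The measurements above time only the first call. To see how much of that time is warmup, the first call can be compared with the median of a few later calls. The following is a minimal sketch under stated assumptions: `first_vs_steady` is a hypothetical helper, it uses wall-clock time only, and on CUDA the function being timed would need to synchronize itself (for example by calling torch.cuda.synchronize()) for the numbers to be meaningful. The demo uses a stand-in function rather than a real model.

```python
import statistics
import time


def first_vs_steady(fn, warm_calls=5):
    """Return (first_call_ms, median_later_call_ms) for a zero-argument callable."""
    start = time.perf_counter()
    fn()
    first_ms = (time.perf_counter() - start) * 1000

    later_ms = []
    for _ in range(warm_calls):
        start = time.perf_counter()
        fn()
        later_ms.append((time.perf_counter() - start) * 1000)
    return first_ms, statistics.median(later_ms)


if __name__ == "__main__":
    calls = {"count": 0}

    def fake_model():
        # Stand-in for a model whose first call pays a one-time warmup cost.
        calls["count"] += 1
        if calls["count"] == 1:
            time.sleep(0.05)

    first_ms, steady_ms = first_vs_steady(fake_model)
    print(f"first call: {first_ms:.2f} ms, steady state: {steady_ms:.2f} ms")
```

With a real model, `fn` would be something like `lambda: model(example_inputs)`; a large gap between the two numbers indicates a one-time cost paid on the first call.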
Conclusion
In this tutorial, we have learned how to use AOTInductor with the Python runtime by compiling and loading a pretrained ResNet18 model. This process demonstrates the practical application of generating a compiled artifact and running it within a Python environment. We also looked at the advantage of using AOTInductor in model deployments with regard to the speedup in first inference time.
Total running time of the script: ( 0 minutes 28.302 seconds)