(prototype) GPU Quantization with TorchAO
Created On: Feb 06, 2024 | Last Updated: Oct 01, 2024 | Last Verified: Nov 05, 2024
Author: HDCharles
In this tutorial, we will walk you through the quantization and optimization of the popular Segment Anything Model. These steps mimic some of those taken to develop the segment-anything-fast repo. This step-by-step guide demonstrates how you can apply these techniques to speed up your own models, especially those that use transformers. To that end, we focus on widely applicable techniques, such as optimizing performance with torch.compile and quantization, and measure their impact.
Set up Your Environment
First, let’s configure your environment. This guide was written for CUDA 12.1. We ran this tutorial on an A100-PG509-200 power-limited to 330.00 W. If you are using different hardware, you might see different performance numbers.
> conda create -n myenv python=3.10
> conda activate myenv
> pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
> pip install git+https://github.com/facebookresearch/segment-anything.git
> pip install git+https://github.com/pytorch-labs/ao.git
Segment Anything Model checkpoint setup:
Go to the segment-anything repo and download the vit_h checkpoint. Alternatively, you can use wget (for example, wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>). Pass in that directory by editing the code below to say:
sam_checkpoint_base_path = "<path>"
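If the checkpoint is missing, it helps to find out before any model code runs. A minimal sketch of such an up-front check using only the standard library; the helper find_checkpoint is hypothetical and not part of segment-anything:

```python
from pathlib import Path


def find_checkpoint(base_path: str, model_name: str = "sam_vit_h_4b8939.pth") -> Path:
    """Return base_path/model_name, or raise a readable error if it was not downloaded."""
    path = Path(base_path) / model_name
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found: download the vit_h checkpoint into {base_path!r} "
            "or change sam_checkpoint_base_path"
        )
    return path
```

You could then build checkpoint_path in the code below as str(find_checkpoint(sam_checkpoint_base_path)).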
import torch
from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True
@torch.no_grad()
def benchmark(f, *args, **kwargs):
    # warm up (for torch.compile'd models this also triggers compilation)
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    # median runtime in ms and peak allocated memory in GB
    return {'time': res.median * 1e3, 'memory': torch.cuda.max_memory_allocated() / 1e9}
def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')
    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        # a block consumes patch embeddings laid out as (batch, height, width, channels)
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image
In this tutorial, we focus on quantizing the image_encoder because its inputs are statically sized, while the prompt encoder and mask decoder have variable sizes, which makes them harder to quantize.
We’ll focus on just a single block at first to make the analysis easier.
Let’s start by measuring the baseline runtime.
try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)
base fp32 runtime of the model is 200.66ms and peak memory 6.33GB
We can achieve an instant performance boost by converting the model to bfloat16. We opt for bfloat16 over fp16 because of its dynamic range, which is comparable to that of fp32: both bfloat16 and fp32 have 8 exponent bits, whereas fp16 has only 5. This larger dynamic range helps protect us from overflow errors and other issues that can arise when scaling and rescaling tensors during quantization.
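The dynamic-range argument can be checked with a little arithmetic in plain Python. This sketch computes the largest finite value of each format from its exponent and mantissa widths, uses the struct module's half-precision "e" format to test fp16 overflow, and approximates bfloat16 by truncating an fp32 bit pattern. The truncation is a simplifying assumption: PyTorch itself rounds to nearest when converting to bfloat16, so the last bit can differ.

```python
import struct


def max_finite(exp_bits: int, mantissa_bits: int) -> float:
    """Largest finite value of an IEEE-754-style binary format."""
    bias = 2 ** (exp_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the largest usable
    # exponent equals the bias; the mantissa is then all ones: (2 - 2**-m).
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** bias


# (exponent bits, stored mantissa bits)
FORMATS = {"fp32": (8, 23), "bf16": (8, 7), "fp16": (5, 10)}


def to_bf16(x: float) -> float:
    """Approximate bfloat16 by keeping the top 16 bits of the fp32 pattern (round toward zero)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    (y,) = struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))
    return y


def fits_fp16(x: float) -> bool:
    """True if x can be stored as fp16 without overflowing."""
    try:
        struct.pack("<e", x)  # "e" is the IEEE half-precision struct format
        return True
    except (OverflowError, struct.error):
        return False


for name, (e, m) in FORMATS.items():
    print(f"{name}: max finite value {max_finite(e, m):.4g}")
print("70000.0 fits in fp16:", fits_fp16(70000.0))
print("70000.0 in (truncated) bf16:", to_bf16(70000.0))
```

fp16 tops out at 65504, while bf16 reaches roughly the same ~3.4e38 as fp32. Truncating 70000.0 to bf16 gives 69632.0: the value stays finite at the cost of precision, whereas fp16 cannot represent it at all.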
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']:0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory 3.17GB
bf16 runtime of the block is 70.49ms and peak memory 3.17GB
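In the tests we conducted, this quick change alone improves runtime by about 7x (186.16 ms to 25.43 ms). The output shown above comes from a different run and shows a smaller gain, so expect your numbers to vary with your hardware.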
Next, let’s use torch.compile with our model to see how much the performance improves.
model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']:0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory 2.24GB
AUTOTUNE mm(65536x1280, 1280x5120)
triton_mm_124 12.6710 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_121 12.7396 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_122 12.7734 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_117 12.7918 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_118 12.8133 ms 98.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_114 12.9444 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_119 12.9720 ms 97.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 13.0150 ms 97.4%
triton_mm_115 13.2024 ms 96.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_113 13.3939 ms 94.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1042 seconds and 0.7216 seconds precompiling for 20 choices
AUTOTUNE mm(65536x5120, 5120x1280)
triton_mm_143 12.5225 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 12.5563 ms 99.7%
triton_mm_140 12.6474 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_138 12.7406 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_136 12.8666 ms 97.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_137 12.9270 ms 96.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_141 13.0693 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_134 13.5567 ms 92.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_133 13.6858 ms 91.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_144 13.8066 ms 90.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1415 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x1280)
triton_mm_99 3.8144 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_98 3.8154 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_105 3.8185 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_102 3.8380 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_103 3.8420 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_100 3.8984 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_95 3.9240 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_96 3.9363 ms 96.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_94 3.9516 ms 96.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
mm 3.9557 ms 96.4%
SingleProcess AUTOTUNE benchmarking takes 1.0981 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x3840)
triton_mm_16 11.3910 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_13 11.4739 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_9 11.4852 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_14 11.4924 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_10 11.5149 ms 98.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_11 11.6060 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 11.6828 ms 97.5%
triton_mm_6 11.8405 ms 96.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_7 11.9153 ms 95.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_5 12.0586 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.0337 seconds and 0.0018 seconds precompiling for 20 choices
AUTOTUNE bmm(6400x196x80, 6400x80x196)
triton_bmm_29 1.8237 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_33 1.8401 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_25 1.9323 ms 94.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_34 1.9436 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
triton_bmm_35 1.9476 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_28 1.9599 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_27 1.9692 ms 92.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_23 2.0142 ms 90.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_26 2.0326 ms 89.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_bmm_30 2.0429 ms 89.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9446 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE bmm(14x89600x80, 14x80x16)
triton_bmm_53 0.4690 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_39 0.4731 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
triton_bmm_42 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_48 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
bmm 0.4833 ms 97.0%
triton_bmm_45 0.4915 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_52 0.4915 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_41 0.5028 ms 93.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_40 0.5089 ms 92.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
triton_bmm_47 0.5263 ms 89.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.6906 seconds and 0.0003 seconds precompiling for 17 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
triton_bmm_55 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
triton_bmm_58 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_64 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_69 0.5448 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_61 0.5560 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_68 0.5571 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_57 0.5755 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_56 0.6164 ms 88.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
triton_bmm_63 0.6390 ms 85.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_66 0.6861 ms 79.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7046 seconds and 0.0003 seconds precompiling for 17 choices
AUTOTUNE bmm(6400x196x196, 6400x196x80)
triton_bmm_81 1.8504 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_71 1.8852 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_72 1.9732 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_80 1.9835 ms 93.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_77 2.0060 ms 92.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_bmm_79 2.0183 ms 91.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_76 2.0490 ms 90.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_85 2.0859 ms 88.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
triton_bmm_84 2.1207 ms 87.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_86 2.1688 ms 85.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8771 seconds and 0.0003 seconds precompiling for 20 choices
bf16 compiled runtime of the block is 59.48ms and peak memory 2.24GB
The first time this is run, you should see a sequence of AUTOTUNE outputs, which occur when Inductor compares the performance of various kernel parameters for a kernel. This only happens once (unless you delete your cache), so if you run the cell again you should get just the benchmark output.
torch.compile yields about another 27% speedup (25.43 ms to 19.95 ms in our tests). This brings the model to a reasonable baseline, where we now have to work a bit harder for further improvements.
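The speedup figures quoted so far are ratios of runtimes, and a speedup is not the same number as the fraction of time saved. A small sketch of the arithmetic, using the runtimes from the tutorial’s own tests; the helper names are hypothetical:

```python
def speedup(before_ms: float, after_ms: float) -> float:
    """How many times faster the new runtime is."""
    return before_ms / after_ms


def time_saved(before_ms: float, after_ms: float) -> float:
    """Fraction of the original runtime that was eliminated."""
    return 1.0 - after_ms / before_ms


# Runtimes reported in this tutorial's own tests (ms)
fp32, bf16, bf16_compiled = 186.16, 25.43, 19.95
print(f"fp32 -> bf16: {speedup(fp32, bf16):.2f}x faster")
print(f"bf16 -> compiled: {speedup(bf16, bf16_compiled):.2f}x faster, "
      f"{time_saved(bf16, bf16_compiled):.1%} less time")
```

This gives about 7.32x for the bfloat16 conversion and 1.27x for torch.compile; the latter "27% improvement" corresponds to spending about 21.5% less time per call.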
Next, let’s apply quantization. In torchao, which is written in native PyTorch and Python, quantization for GPUs comes in three main forms:
int8 dynamic quantization
int8 weight-only quantization
int4 weight-only quantization
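The difference between the two int8 forms can be sketched for a single linear layer in plain Python. This assumes symmetric quantization with one scale per row (per token for activations, per output channel for weights); the helper names are hypothetical and this is not torchao’s implementation. int4 weight-only follows the weight-only pattern with 4-bit values and is omitted.

```python
import random
from typing import List, Tuple

Matrix = List[List[float]]


def quantize_rows(m: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric int8 quantization with one scale per row: max |value| maps to 127."""
    q, scales = [], []
    for row in m:
        scale = max(abs(v) for v in row) / 127 or 1.0  # all-zero row: any scale works
        q.append([max(-128, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q, scales


def linear_float(x: Matrix, w: Matrix) -> Matrix:
    """Reference y = x @ w.T, with w laid out (out_features, in_features) like nn.Linear."""
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in w] for xr in x]


def linear_int8_dynamic(x: Matrix, wq: List[List[int]], w_scales: List[float]) -> Matrix:
    """Dynamic quantization: activations are quantized at run time, then an integer matmul."""
    xq, x_scales = quantize_rows(x)  # one scale per row (token), computed on the fly
    return [
        # integer dot product (exact here; int32 accumulation on a GPU), rescaled at the end
        [sum(a * b for a, b in zip(xr, wr)) * sx * sw for wr, sw in zip(wq, w_scales)]
        for xr, sx in zip(xq, x_scales)
    ]


def linear_int8_weight_only(x: Matrix, wq: List[List[int]], w_scales: List[float]) -> Matrix:
    """Weight-only: int8 weights are dequantized and the matmul stays in floating point."""
    w_deq = [[v * s for v in row] for row, s in zip(wq, w_scales)]
    return linear_float(x, w_deq)


random.seed(0)
x = [[random.uniform(-1, 1) for _ in range(16)] for _ in range(4)]  # 4 tokens, 16 features
w = [[random.uniform(-1, 1) for _ in range(16)] for _ in range(8)]  # 8 output features
wq, w_scales = quantize_rows(w)  # done once, ahead of time, for both schemes
ref = linear_float(x, w)
for name, y in [("dynamic", linear_int8_dynamic(x, wq, w_scales)),
                ("weight-only", linear_int8_weight_only(x, wq, w_scales))]:
    err = max(abs(a - b) for ra, rb in zip(ref, y) for a, b in zip(ra, rb))
    print(f"int8 {name}: max abs error vs float = {err:.4f}")
```

The dynamic version does its expensive inner loop on integers, which is where compute-bound models gain; the weight-only version still multiplies in floating point and saves only on the bytes needed to store and load the weights.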
Different models, or sometimes different layers in a model, can require different techniques. For models that are heavily compute-bound, dynamic quantization tends to work best, since it swaps the normal, expensive floating-point matmul ops for integer versions. Weight-only quantization works better in memory-bound situations, where the benefit comes from loading less weight data rather than doing less computation. The torchao APIs int8_dynamic_activation_int8_weight(), int8_weight_only(), and int4_weight_only() can be passed to quantize_ to apply the desired quantization technique; once the model is then compiled with torch.compile using max-autotune, quantization is complete and we can see our speedup.
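The compute-bound versus memory-bound distinction can be made concrete with arithmetic intensity, the number of FLOPs a matmul performs per byte it moves. A rough sketch, under the idealized assumption that each operand is read once and the output written once in bf16 (real kernels move more data); the helper names are hypothetical. The shapes come from the mm(65536x1280, 1280x5120) entry in the AUTOTUNE log, where 65536 rows are a batch of 16 images times 64 x 64 patch tokens.

```python
def matmul_intensity(m: int, k: int, n: int, bytes_per_elem: int = 2) -> float:
    """FLOPs per byte for an (m x k) @ (k x n) matmul, moving each operand and the output once."""
    flops = 2 * m * k * n  # one multiply and one add per inner-product term
    bytes_moved = (m * k + k * n + m * n) * bytes_per_elem
    return flops / bytes_moved


def weight_mib(out_features: int, in_features: int, bits: int) -> float:
    """Storage for a linear layer's weight at a given bit width, ignoring scales."""
    return out_features * in_features * bits / 8 / 2**20


print(f"SAM block, batch 16: {matmul_intensity(65536, 1280, 5120):.0f} FLOP/byte")
# Hypothetical contrast: the same weight applied to a single token
print(f"single token: {matmul_intensity(1, 1280, 5120):.2f} FLOP/byte")
for bits in (16, 8, 4):
    print(f"{bits}-bit weight: {weight_mib(5120, 1280, bits):.3f} MiB")
```

The SAM block matmul does roughly 1008 FLOPs per byte, while the same weight applied to one token does about 1. A GPU’s ridge point (peak FLOP/s divided by memory bandwidth) is, roughly, in the low hundreds of FLOP/byte for bf16 on recent datacenter GPUs, so the first case is compute-bound and the second memory-bound, where shrinking the weight from 12.5 MiB to 6.25 or 3.125 MiB is what pays off.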
Note
You might experience issues with these on older versions of PyTorch. If you run into an issue, you can use apply_dynamic_quant and apply_weight_only_int8_quant instead, as drop-in replacements for the two int8 APIs above (there is no replacement for int4).
The difference between the two sets of APIs is that the int8_dynamic_activation_int8_weight API alters the weight tensor of the linear module, so that instead of doing a normal linear, it performs a quantized operation. This is helpful when you have non-standard linear ops that do more than one thing. The apply APIs directly swap the linear modules for quantized modules, which works on older versions of PyTorch but doesn’t work with non-standard linear modules.
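A toy, plain-Python illustration of why swapping the weight works with non-standard linear modules while swapping modules does not. Every class and function here is hypothetical and only stands in for the real mechanisms (a quantized tensor subclass versus module replacement); it is not torchao code.

```python
from typing import Dict, List


class FloatWeight:
    def __init__(self, rows: List[List[float]]):
        self.rows = rows

    def matvec(self, x: List[float]) -> List[float]:
        return [sum(w * v for w, v in zip(row, x)) for row in self.rows]


class Int8Weight:
    """Stand-in for a quantized tensor subclass: same matvec interface, int8 storage."""
    def __init__(self, weight: FloatWeight):
        self.scales = [max(abs(w) for w in row) / 127 or 1.0 for row in weight.rows]
        self.q = [[round(w / s) for w in row] for row, s in zip(weight.rows, self.scales)]

    def matvec(self, x: List[float]) -> List[float]:
        return [s * sum(q * v for q, v in zip(row, x)) for row, s in zip(self.q, self.scales)]


class Linear:
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x: List[float]) -> List[float]:
        return self.weight.matvec(x)


class LinearReLU:
    """A 'non-standard' linear: it uses its weight but also does extra work."""
    def __init__(self, weight):
        self.weight = weight

    def __call__(self, x: List[float]) -> List[float]:
        return [max(0.0, y) for y in self.weight.matvec(x)]


class QuantLinear(Linear):
    """What a module-swap API would install in place of a plain Linear."""
    def __init__(self, linear: Linear):
        super().__init__(Int8Weight(linear.weight))


def swap_modules(modules: Dict[str, object]) -> Dict[str, object]:
    # Module swap: only modules recognized as plain Linear are replaced.
    return {name: QuantLinear(m) if type(m) is Linear else m for name, m in modules.items()}


def swap_weights(modules: Dict[str, object]) -> Dict[str, object]:
    # Weight swap: any module holding a float weight gets a quantized one, and its
    # own forward logic keeps working because the weight's interface is unchanged.
    for m in modules.values():
        if isinstance(getattr(m, "weight", None), FloatWeight):
            m.weight = Int8Weight(m.weight)
    return modules


w = [[0.5, -1.0], [2.0, 0.25]]
by_module = swap_modules({"fc": Linear(FloatWeight(w)), "fused": LinearReLU(FloatWeight(w))})
by_weight = swap_weights({"fc": Linear(FloatWeight(w)), "fused": LinearReLU(FloatWeight(w))})
print(type(by_module["fused"].weight).__name__)  # the fused module was left in float
print(type(by_weight["fused"].weight).__name__)  # the fused module now uses int8
```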
In this case, Segment Anything is compute-bound, so we’ll use dynamic quantization:
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
quantize_(model, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']:0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory 3.58GB
AUTOTUNE int_mm(65536x5120, 5120x1280)
triton_mm_257 6.2730 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 6.4635 ms 97.1%
triton_mm_258 6.6191 ms 94.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_259 6.6857 ms 93.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_251 6.7287 ms 93.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_250 6.7379 ms 93.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_252 6.9120 ms 90.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_256 7.0164 ms 89.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
triton_mm_249 7.1793 ms 87.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_253 7.1854 ms 87.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9759 seconds and 0.1450 seconds precompiling for 12 choices
AUTOTUNE int_mm(78400x1280, 1280x3840)
triton_mm_154 5.7938 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_155 6.2966 ms 92.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_156 6.3171 ms 91.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 6.3601 ms 91.1%
triton_mm_147 6.4840 ms 89.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_146 6.7850 ms 85.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_148 7.0461 ms 82.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_149 7.2438 ms 80.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_150 7.6298 ms 75.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_153 9.0020 ms 64.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9755 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE bmm(6400x196x80, 6400x80x196)
triton_bmm_167 1.8135 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_171 1.8237 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_163 1.9374 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_173 1.9384 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_172 1.9415 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
triton_bmm_166 1.9548 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_168 1.9579 ms 92.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_161 2.0060 ms 90.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_169 2.0255 ms 89.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_164 2.0285 ms 89.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9467 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE bmm(14x89600x80, 14x80x16)
triton_bmm_191 0.4690 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_177 0.4731 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
triton_bmm_180 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_186 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
bmm 0.4844 ms 96.8%
triton_bmm_183 0.4915 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_190 0.4915 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_179 0.5038 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_178 0.5151 ms 91.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
triton_bmm_185 0.5243 ms 89.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.6925 seconds and 0.0002 seconds precompiling for 17 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
triton_bmm_193 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
triton_bmm_196 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_202 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_207 0.5448 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_206 0.5581 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_199 0.5591 ms 97.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_195 0.5755 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_194 0.6093 ms 89.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
triton_bmm_201 0.6369 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_204 0.6819 ms 79.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7102 seconds and 0.0003 seconds precompiling for 17 choices
AUTOTUNE bmm(6400x196x196, 6400x196x80)
triton_bmm_219 1.8493 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_209 1.8842 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_210 1.9661 ms 94.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_218 1.9845 ms 93.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_215 1.9958 ms 92.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_bmm_217 2.0204 ms 91.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_214 2.0490 ms 90.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_223 2.0951 ms 88.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
triton_bmm_222 2.1197 ms 87.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
triton_bmm_224 2.1688 ms 85.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8793 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE int_mm(78400x1280, 1280x1280)
triton_mm_235 1.9436 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_236 2.1596 ms 90.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_237 2.1668 ms 89.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 2.1985 ms 88.4%
triton_mm_227 2.2579 ms 86.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_228 2.2804 ms 85.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_230 2.3101 ms 84.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_229 2.4607 ms 79.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_231 2.5446 ms 76.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_234 2.9829 ms 65.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.6236 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE int_mm(65536x1280, 1280x5120)
triton_mm_246 6.4451 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_247 7.0001 ms 92.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_248 7.0216 ms 91.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 7.0963 ms 90.8%
triton_mm_239 7.1444 ms 90.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_238 7.5602 ms 85.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_240 7.7619 ms 83.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_241 8.0394 ms 80.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_242 8.3855 ms 76.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_245 9.9748 ms 64.6% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 1.0487 seconds and 0.0002 seconds precompiling for 12 choices
bf16 compiled runtime of the quantized block is 46.09ms and peak memory 3.55GB
With quantization, we have improved performance a bit more, but memory usage has increased significantly.
This is for two reasons:
1. Quantization adds overhead to the model, since we need to quantize and dequantize the inputs and outputs. For small batch sizes, this overhead can actually make the model slower.
2. Even though we are doing a quantized matmul, such as int8 x int8, the result of the multiplication is stored in an int32 tensor, which is twice the size of the bf16 result from the non-quantized model. If we can avoid creating this int32 tensor, memory usage will improve considerably.
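To put the second reason in numbers, the sketch below computes how large the matmul result is for one of the shapes in the autotune log above, int_mm(65536x1280, 1280x5120), when it is stored as int32 compared with bf16. It only does byte arithmetic on dtype sizes; the helper name `output_bytes` is our own and not part of torchao.

```python
import torch

def output_bytes(m: int, n: int, dtype: torch.dtype) -> int:
    """Bytes needed to store an (m x n) matmul result in the given dtype."""
    # element_size() reports the number of bytes per element for the tensor's dtype
    bytes_per_elem = torch.empty(0, dtype=dtype).element_size()
    return m * n * bytes_per_elem

# int_mm(65536x1280, 1280x5120) produces a 65536 x 5120 result
M, N = 65536, 5120
int32_bytes = output_bytes(M, N, torch.int32)
bf16_bytes = output_bytes(M, N, torch.bfloat16)

print(f"int32 result: {int32_bytes / 2**30:.3f} GiB")
print(f"bf16 result:  {bf16_bytes / 2**30:.3f} GiB")
```

For this single intermediate tensor, that is 1.25 GiB as int32 versus 0.625 GiB as bf16, and a block contains several such matmuls.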
We can fix the second issue by fusing the integer matmul with the subsequent rescale operation. Since the final output will be bf16, converting the int32 result to bf16 immediately and storing only that gives better performance in terms of both runtime and memory.
To do this, enable the force_fuse_int_mm_with_mul option in the inductor config.
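Conceptually, the fusion turns "int8 matmul into an int32 buffer, then rescale to bf16" into one step that writes bf16 directly. The sketch below spells out the unfused math on CPU with plain PyTorch ops so the rescale is visible. It assumes symmetric int8 quantization with one scale per activation row and one per weight column, which is a simplification and not necessarily how torchao lays out its scales; `quantize_symmetric` and `int8_linear_reference` are hypothetical helper names, not torchao APIs.

```python
import torch

def quantize_symmetric(t: torch.Tensor, dim: int):
    """Symmetric int8 quantization with one scale per slice reduced over `dim`."""
    # amax over `dim` yields one max-abs value per row (dim=1) or per column (dim=0)
    scale = t.abs().amax(dim=dim, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-128, 127).to(torch.int8)
    return q, scale

def int8_linear_reference(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    """Unfused reference: int8 x int8 -> int32 accumulator -> rescale -> bf16.

    w is laid out as (K, N) here; nn.Linear stores its weight as (N, K).
    """
    xq, x_scale = quantize_symmetric(x, dim=1)  # per-row scales, shape (M, 1)
    wq, w_scale = quantize_symmetric(w, dim=0)  # per-column scales, shape (1, N)
    # Integer matmul via broadcasting, accumulating in int32 so products do not overflow
    prods = xq.to(torch.int32)[:, :, None] * wq.to(torch.int32)[None, :, :]
    acc_int32 = prods.sum(dim=1, dtype=torch.int32)
    # acc_int32 is the intermediate buffer that the fused kernel avoids storing.
    # Rescale by both scales and cast to the final bf16 output.
    return (acc_int32.to(torch.float32) * x_scale * w_scale).to(torch.bfloat16)

torch.manual_seed(0)
x = torch.randn(8, 64)
w = torch.randn(64, 16)
y = int8_linear_reference(x, w)
ref = x @ w
rel_err = ((y.float() - ref).norm() / ref.norm()).item()
print(y.dtype, tuple(y.shape), f"relative error vs fp32: {rel_err:.4f}")
```

A fused kernel would apply the scale multiplication and bf16 cast to each output tile while it is still in registers, so the int32 tensor never reaches global memory.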
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
quantize_(model, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory 2.37GB
bf16 compiled runtime of the fused quantized block is 45.78ms and peak memory 3.54GB
Based on the reference results in the code comments, the fusion improves performance by another small amount (about 6% over the baseline in total) and removes almost all of the memory increase. The remaining difference (2.37GB quantized vs 2.24GB unquantized) is due to quantization overhead, which cannot be avoided.
We’re still not done, though: we can apply a few general-purpose optimizations to get our final best-case performance.
We can sometimes improve performance by disabling epilogue fusion, since fusions can confuse the autotuning process into choosing bad kernel parameters.
We can also apply coordinate descent tuning in all directions to enlarge the search space for kernel parameters.
del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
quantize_(model, int8_dynamic_activation_int8_weight())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory 2.39GB
bf16 compiled runtime of the final quantized block is 46.08ms and peak memory 3.54GB
As you can see, we’ve squeezed another small improvement from the model, taking our total improvement to over 10x compared to the original. To get a final estimate of the impact of quantization, let’s do an apples-to-apples comparison on the full model, since the actual improvement will differ from block to block depending on the shapes involved.
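One way to see why block-level gains do not carry over one-to-one is to compare, for identical matmul shapes, the best int8 (int_mm) and bf16 (mm) kernel times that autotuning reports in the logs in this section. The sketch below only divides those logged numbers. It is kernel-level arithmetic and ignores quantization overhead, attention, convolutions, and everything else that contributes to end-to-end latency.

```python
# Best autotuned kernel times (ms) copied from the AUTOTUNE logs in this section,
# keyed by the (M x K) @ (K x N) shape of each linear layer's matmul.
bf16_mm_ms = {
    "65536x1280 @ 1280x3840": 9.5181,
    "65536x1280 @ 1280x1280": 3.1805,
    "65536x1280 @ 1280x5120": 12.6853,
}
int8_mm_ms = {
    "65536x1280 @ 1280x3840": 4.8384,
    "65536x1280 @ 1280x1280": 1.6271,
    "65536x1280 @ 1280x5120": 6.4451,
}

# Ratio > 1 means the int8 kernel is faster for that shape
kernel_speedup = {shape: bf16_mm_ms[shape] / int8_mm_ms[shape] for shape in bf16_mm_ms}
for shape, ratio in kernel_speedup.items():
    print(f"{shape}: int8 kernel is {ratio:.2f}x faster than bf16")
```

Each of these matmuls runs roughly twice as fast in int8, yet the end-to-end gains reported for whole blocks and the full model are far smaller, which suggests these matmuls account for only part of the total runtime.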
try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory 23.96GB
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    quantize_(model, int8_dynamic_activation_int8_weight())
    if not TORCH_VERSION_AT_LEAST_2_5:
        # needed for subclass + compile to work on older versions of pytorch
        unwrap_tensor_subclass(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory 24.93GB
except Exception as e:
    print("unable to run full model: ", e)
AUTOTUNE mm(78400x1280, 1280x3840)
triton_mm_277 11.3910 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_283 11.3920 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_280 11.4534 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_281 11.4842 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_276 11.5825 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_278 11.6060 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 11.6787 ms 97.5%
triton_mm_273 11.8641 ms 96.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_274 11.8866 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_272 12.1580 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.0365 seconds and 0.0004 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x1280)
triton_mm_366 3.8154 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_372 3.8195 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_369 3.8390 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_365 3.8482 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_370 3.8748 ms 98.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_362 3.8881 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_367 3.8994 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_363 3.9383 ms 96.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
mm 3.9547 ms 96.5%
triton_mm_361 3.9864 ms 95.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.1058 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x1280, 1280x1280)
triton_mm_1360 3.1805 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1353 3.1826 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1354 3.1846 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_1357 3.2020 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1358 3.2266 ms 98.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_1355 3.2440 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1350 3.2788 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
mm 3.2819 ms 96.9%
triton_mm_1351 3.2932 ms 96.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_1349 3.3423 ms 95.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.0183 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE convolution(16x3x1024x1024, 1280x3x16x16)
convolution 13.5096 ms 100.0%
triton_convolution2d_263 15.1030 ms 89.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_266 17.1694 ms 78.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_261 17.9057 ms 75.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_260 18.6890 ms 72.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_265 24.7747 ms 54.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_264 26.2052 ms 51.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_262 203.9194 ms 6.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.5587 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
triton_bmm_322 0.5448 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
triton_bmm_331 0.5448 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_325 0.5458 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_336 0.5458 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_328 0.5560 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_335 0.5581 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_324 0.5786 ms 94.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_323 0.6216 ms 87.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
triton_bmm_330 0.6380 ms 85.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_333 0.6840 ms 79.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7080 seconds and 0.0002 seconds precompiling for 17 choices
AUTOTUNE mm(65536x1280, 1280x5120)
triton_mm_391 12.6853 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_388 12.7580 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_389 12.7918 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_385 12.8020 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_384 12.8123 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_386 12.9228 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 13.0130 ms 97.5%
triton_mm_381 13.2055 ms 96.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_382 13.2178 ms 96.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_380 13.4124 ms 94.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1084 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x5120, 5120x1280)
triton_mm_410 12.5430 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 12.5532 ms 99.9%
triton_mm_407 12.6566 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_405 12.7457 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_403 12.8205 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_404 12.9526 ms 96.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_408 13.0765 ms 95.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_400 13.3530 ms 93.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_401 13.5588 ms 92.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_411 13.7882 ms 91.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1439 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x1280, 1280x3840)
triton_mm_1305 9.5181 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1299 9.5273 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_1302 9.5652 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1303 9.5713 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_1298 9.6020 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_1300 9.7239 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
mm 9.8488 ms 96.6%
triton_mm_1295 9.9072 ms 96.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_1296 10.0311 ms 94.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_1294 10.1591 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.7652 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE bmm(64x16384x80, 64x80x64)
triton_bmm_1309 0.5970 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_1312 0.6124 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_1325 0.6195 ms 96.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1320 0.6205 ms 96.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_1316 0.6216 ms 96.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_1311 0.6298 ms 94.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1315 0.6328 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_bmm_1319 0.6328 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_1310 0.6369 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1324 0.6410 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.7890 seconds and 0.0003 seconds precompiling for 19 choices
AUTOTUNE bmm(64x16384x80, 64x80x64)
triton_bmm_1330 0.6840 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_1334 0.6881 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
triton_bmm_1338 0.6881 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
triton_bmm_1327 0.6892 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
triton_bmm_1343 0.7076 ms 96.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1329 0.7332 ms 93.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1333 0.7363 ms 92.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
triton_bmm_1337 0.7373 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
triton_bmm_1328 0.7516 ms 91.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
triton_bmm_1342 0.7588 ms 90.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8034 seconds and 0.0003 seconds precompiling for 19 choices
AUTOTUNE convolution(16x1280x64x64, 256x1280x1x1)
triton_convolution2d_4803 0.6789 ms 100.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_4806 0.7690 ms 88.3% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
convolution 0.7793 ms 87.1%
triton_convolution2d_4808 0.9288 ms 73.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_4804 1.4008 ms 48.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
triton_convolution2d_4809 1.4172 ms 47.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
triton_convolution2d_4807 1.4469 ms 46.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
conv1x1_via_mm 2.2835 ms 29.7%
triton_convolution2d_4805 4.7688 ms 14.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4156 seconds and 0.0002 seconds precompiling for 9 choices
AUTOTUNE convolution(16x256x64x64, 256x256x3x3)
convolution 1.4469 ms 100.0%
triton_convolution2d_4811 2.6194 ms 55.2% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_4816 2.8088 ms 51.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_4813 3.5277 ms 41.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_4814 4.3295 ms 33.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_4810 4.8323 ms 29.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
triton_convolution2d_4815 6.2659 ms 23.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
triton_convolution2d_4812 9.1843 ms 15.8% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4976 seconds and 0.0002 seconds precompiling for 8 choices
bf16 compiled runtime of the compiled full model is 2167.30ms and peak memory 15.29GB
AUTOTUNE int_mm(65536x1280, 1280x3840)
triton_mm_5630 4.8384 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_5631 5.2705 ms 91.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_5632 5.2879 ms 91.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 5.3371 ms 90.7%
triton_mm_5623 5.5194 ms 87.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_5622 5.7416 ms 84.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_5624 5.9597 ms 81.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_5625 6.0078 ms 80.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_5626 6.3867 ms 75.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_5629 7.5131 ms 64.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9052 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE int_mm(65536x1280, 1280x1280)
triton_mm_5677 1.6271 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_5678 1.8145 ms 89.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
triton_mm_5679 1.8156 ms 89.6% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
_int_mm 1.8442 ms 88.2%
triton_mm_5669 1.8719 ms 86.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
triton_mm_5670 1.8954 ms 85.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_5672 1.9169 ms 84.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_5671 2.0439 ms 79.6% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
triton_mm_5673 2.1176 ms 76.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
triton_mm_5676 2.4852 ms 65.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.5979 seconds and 0.0002 seconds precompiling for 12 choices
unable to run full model: CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 22.07 GiB of which 3.19 GiB is free. Including non-PyTorch memory, this process has 18.85 GiB memory in use. Of the allocated memory 15.55 GiB is allocated by PyTorch, with 6.61 GiB allocated in private pools (e.g., CUDA Graphs), and 2.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
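The full-model run above ran out of GPU memory on the machine that built this page, and the error reports 6.61 GiB held in private pools such as CUDA Graphs, which mode='max-autotune' enables by default. Releasing as much as possible between benchmarks may help. The helper below is a hypothetical best-effort sketch, not part of the tutorial: after you delete your own references, it clears torch.compile's caches with torch._dynamo.reset(), runs the garbage collector, and returns cached allocator blocks with torch.cuda.empty_cache(). Whether this frees enough memory depends on your GPU. The error message's own suggestion, PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, is another option and has to be set before CUDA is initialized.

```python
import gc

import torch
import torch._dynamo

def release_gpu_memory() -> int:
    """Best-effort cleanup between benchmark runs; returns bytes still allocated.

    Callers should `del` their own references (e.g. model_c, model, image) first;
    nothing here can free tensors that are still referenced elsewhere.
    """
    torch._dynamo.reset()         # clear compiled code and caches kept by torch.compile
    gc.collect()                  # collect now-unreferenced Python objects and their tensors
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return unused cached blocks to the CUDA driver
        return torch.cuda.memory_allocated()
    return 0

# Usage between runs:
# del model_c, model, image
# print(f"still allocated: {release_gpu_memory() / 2**30:.2f} GiB")
remaining = release_gpu_memory()
```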
Conclusion
In this tutorial, we explored quantization and optimization techniques using the Segment Anything model as an example.
In the end, we achieved a full-model, apples-to-apples quantization speedup of about 7.7% at batch size 16 (729.65ms to 677.28ms). This could be pushed a bit further by increasing the batch size and optimizing other parts of the model, for example with some form of flash attention.
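The 7.7% figure is a speedup (a ratio of runtimes), which differs slightly from the fraction of time saved per batch. A quick check with the reference full-model numbers from the code comments:

```python
baseline_ms = 729.65   # compiled bf16 full model (reference result)
quantized_ms = 677.28  # compiled full model with int8 dynamic quantization (reference result)

speedup = baseline_ms / quantized_ms - 1.0            # how much faster the quantized model runs
latency_reduction = 1.0 - quantized_ms / baseline_ms  # fraction of per-batch time saved

print(f"speedup: {speedup:.1%}, latency reduction: {latency_reduction:.1%}")
```

This prints a speedup of 7.7% and a latency reduction of 7.2%; either is a valid summary as long as it is labeled.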
For more information, visit torchao and try it on your own models.
Total running time of the script: ( 13 minutes 29.557 seconds)