
(prototype) GPU Quantization with TorchAO

Created On: Feb 06, 2024 | Last Updated: Jul 10, 2025 | Last Verified: Nov 05, 2024

Author: HDCharles

In this tutorial, we walk through the quantization and optimization of the popular Segment Anything model (SAM). The steps mirror some of those taken to develop the segment-anything-fast repo. This step-by-step guide shows how you can apply the same techniques to speed up your own models, especially transformer-based ones. To that end, we focus on widely applicable techniques, such as torch.compile and quantization, and measure their impact.

Set up Your Environment

First, let’s configure your environment. This guide was written for CUDA 12.1. We ran this tutorial on an A100-PG509-200 power limited to 330.00 W. If you are using different hardware, you might see different performance numbers.

> conda create -n myenv python=3.10
> conda activate myenv
> pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
> pip install git+https://github.com/facebookresearch/segment-anything.git
> pip install git+https://github.com/pytorch/ao.git

Segment Anything Model checkpoint setup:

  1. Go to the segment-anything repo checkpoint and download the vit_h checkpoint. Alternatively, you can use wget (for example, wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth --directory-prefix=<path>).

  2. Point the code below at that directory by setting:

sam_checkpoint_base_path = "<path>"

import torch
from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig
from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer

sam_checkpoint_base_path = "data"
model_type = 'vit_h'
model_name = 'sam_vit_h_4b8939.pth'
checkpoint_path = f"{sam_checkpoint_base_path}/{model_name}"
batchsize = 16
only_one_block = True


@torch.no_grad()
def benchmark(f, *args, **kwargs):
    for _ in range(3):
        f(*args, **kwargs)
        torch.cuda.synchronize()

    torch.cuda.reset_peak_memory_stats()
    t0 = Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    res = t0.adaptive_autorange(.03, min_run_time=.2, max_run_time=20)
    return {'time':res.median * 1e3, 'memory': torch.cuda.max_memory_allocated()/1e9}

def get_sam_model(only_one_block=False, batchsize=1):
    sam = sam_model_registry[model_type](checkpoint=checkpoint_path).cuda()
    model = sam.image_encoder.eval()
    image = torch.randn(batchsize, 3, 1024, 1024, device='cuda')

    # code to use just a single block of the model
    if only_one_block:
        model = model.blocks[0]
        image = torch.randn(batchsize, 64, 64, 1280, device='cuda')
    return model, image

In this tutorial, we focus on quantizing the image_encoder because its inputs are statically sized, whereas the prompt encoder and mask decoder take variably sized inputs, which makes them harder to quantize.

We’ll focus on just a single block at first to make the analysis easier.

Let’s start by measuring the baseline runtime.

try:
    model, image = get_sam_model(only_one_block, batchsize)
    fp32_res = benchmark(model, image)
    print(f"base fp32 runtime of the model is {fp32_res['time']:0.2f}ms and peak memory {fp32_res['memory']:0.2f}GB")
    # base fp32 runtime of the model is 186.16ms and peak memory 6.33GB
except Exception as e:
    print("unable to run fp32 model: ", e)
base fp32 runtime of the model is 202.44ms and peak memory 6.33GB

We can achieve an instant performance boost by converting the model to bfloat16. We opt for bfloat16 over fp16 because its dynamic range is comparable to that of fp32. Both bfloat16 and fp32 have 8 exponent bits, whereas fp16 has only 5. This larger dynamic range helps protect us from overflow errors and other issues that can arise when scaling and rescaling tensors due to quantization.
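The dynamic-range argument can be checked with a little arithmetic. The sketch below is plain Python (no GPU required) and derives the largest finite value of each format from its exponent and mantissa bit counts, assuming the standard IEEE-style layout. The helper name max_finite is ours for illustration and is not part of PyTorch.

```python
# Bit layouts as (exponent bits, mantissa bits).
# bfloat16 keeps the exponent width of fp32 and gives up mantissa precision.
FORMATS = {
    "fp32": (8, 23),
    "bfloat16": (8, 7),
    "fp16": (5, 10),
}


def max_finite(exp_bits, man_bits):
    """Largest finite value of an IEEE-style binary floating point format."""
    bias = 2 ** (exp_bits - 1) - 1  # the largest usable exponent equals the bias
    return (2 - 2.0 ** -man_bits) * 2.0 ** bias


for name, (e, m) in FORMATS.items():
    print(f"{name:>8}: {e} exponent bits, max finite value ~ {max_finite(e, m):.4g}")
```

fp16 tops out at 65504, so intermediate values that fit comfortably in fp32 or bfloat16 can overflow in fp16, while bfloat16 reaches nearly the same maximum as fp32.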

model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
bf16_res = benchmark(model, image)
print(f"bf16 runtime of the block is {bf16_res['time']:0.2f}ms and peak memory {bf16_res['memory']: 0.2f}GB")
# bf16 runtime of the block is 25.43ms and peak memory  3.17GB
bf16 runtime of the block is 70.55ms and peak memory  3.17GB

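This one-line change improved runtime by roughly 7x in the tests we conducted (186.16ms to 25.43ms, the figures recorded in the code comments). The captured output shown above comes from a different run, so its absolute numbers differ.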

Next, let’s use torch.compile with our model to see how much the performance improves.

model_c = torch.compile(model, mode='max-autotune')
comp_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the block is {comp_res['time']:0.2f}ms and peak memory {comp_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the block is 19.95ms and peak memory  2.24GB
AUTOTUNE mm(65536x1280, 1280x5120)
  triton_mm_124 12.6689 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_118 12.7396 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_121 12.7508 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_122 12.7590 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_117 12.9157 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  mm 13.0232 ms 97.3%
  triton_mm_119 13.0365 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_115 13.1000 ms 96.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_114 13.3181 ms 95.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_113 13.5660 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1139 seconds and 0.5790 seconds precompiling for 20 choices
AUTOTUNE mm(65536x5120, 5120x1280)
  triton_mm_143 12.5501 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  mm 12.5573 ms 99.9%
  triton_mm_140 12.6833 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_138 12.7489 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_137 12.8451 ms 97.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_136 12.8809 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_141 13.0509 ms 96.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_133 13.3284 ms 94.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_134 13.5813 ms 92.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_144 13.7974 ms 91.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1423 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x1280)
  triton_mm_98 3.8164 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_105 3.8206 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_102 3.8369 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_99 3.8492 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_103 3.8758 ms 98.5% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_95 3.8881 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_100 3.8984 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_96 3.9076 ms 97.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  mm 3.9567 ms 96.5%
  triton_mm_94 3.9711 ms 96.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.0995 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x3840)
  triton_mm_16 11.3818 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_13 11.4729 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_14 11.4760 ms 99.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_9 11.4811 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_10 11.5149 ms 98.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_11 11.6111 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_6 11.7309 ms 97.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  mm 11.7955 ms 96.5%
  triton_mm_7 11.9378 ms 95.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_5 12.1569 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.0391 seconds and 0.0018 seconds precompiling for 20 choices
AUTOTUNE bmm(6400x196x80, 6400x80x196)
  triton_bmm_29 1.8238 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_33 1.8412 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_25 1.9333 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_34 1.9416 ms 93.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
  triton_bmm_35 1.9446 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_28 1.9630 ms 92.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_27 1.9692 ms 92.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_23 2.0081 ms 90.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_30 2.0470 ms 89.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_26 2.0480 ms 89.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9431 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE bmm(14x89600x80, 14x80x16)
  triton_bmm_53 0.4689 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_39 0.4731 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
  triton_bmm_42 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_48 0.4813 ms 97.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  bmm 0.4823 ms 97.2%
  triton_bmm_52 0.4915 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_45 0.4925 ms 95.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_41 0.5028 ms 93.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_40 0.5110 ms 91.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
  triton_bmm_47 0.5263 ms 89.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.6912 seconds and 0.0002 seconds precompiling for 17 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
  triton_bmm_55 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
  triton_bmm_58 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_64 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_69 0.5448 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_61 0.5540 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_68 0.5571 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_57 0.5745 ms 94.7% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_56 0.6175 ms 88.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
  triton_bmm_63 0.6380 ms 85.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_66 0.6871 ms 79.1% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7043 seconds and 0.0002 seconds precompiling for 17 choices
AUTOTUNE bmm(6400x196x196, 6400x196x80)
  triton_bmm_81 1.8483 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_71 1.8883 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_72 1.9743 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_80 1.9825 ms 93.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_77 2.0050 ms 92.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
  triton_bmm_79 2.0183 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_76 2.0490 ms 90.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_85 2.0920 ms 88.4% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
  triton_bmm_84 2.1207 ms 87.2% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_86 2.1678 ms 85.3% ACC_TYPE='tl.float32', ALLOW_TF32=False, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8774 seconds and 0.0003 seconds precompiling for 20 choices
bf16 compiled runtime of the block is 59.55ms and peak memory  2.24GB

The first time this is run, you should see a sequence of AUTOTUNE outputs, which occur when Inductor compares the performance of various kernel parameters for a kernel. This only happens once (unless you delete your cache), so if you run the cell again you should just get the benchmark output.

torch.compile yields about another 27% improvement (25.43ms to 19.95ms in our tests). This brings the model to a reasonable baseline where we now have to work a bit harder for improvements.
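For reference, the speedup figures quoted so far can be reproduced from the timings recorded in the code comments. This is a small worked calculation in plain Python; the speedup helper is a name we introduce here, and your own measurements will differ with hardware.

```python
# Timings in milliseconds, taken from the code comments in this tutorial.
timings_ms = {"fp32": 186.16, "bf16": 25.43, "bf16 + compile": 19.95}


def speedup(before_ms, after_ms):
    """How many times faster the `after` run is than the `before` run."""
    return before_ms / after_ms


bf16_gain = speedup(timings_ms["fp32"], timings_ms["bf16"])
compile_gain = speedup(timings_ms["bf16"], timings_ms["bf16 + compile"])

print(f"fp32 -> bf16: {bf16_gain:.2f}x faster")
print(f"bf16 -> bf16 + compile: {compile_gain:.2f}x faster "
      f"({(compile_gain - 1) * 100:.0f}% more throughput)")
```

Note that the 27% figure is a throughput gain (old time divided by new time); expressed as a reduction in runtime, the same change is about 22%.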

Next, let’s apply quantization. torchao, which is written in native PyTorch and Python, offers three main forms of GPU quantization:

  • int8 dynamic quantization

  • int8 weight-only quantization

  • int4 weight-only quantization

Different models, or sometimes different layers in a model, can require different techniques. For models which are heavily compute-bound, dynamic quantization tends to work best since it swaps the expensive floating point matmul ops for integer versions. Weight-only quantization works better in memory-bound situations, where the benefit comes from loading less weight data rather than doing less computation. The torchao APIs:

Int8DynamicActivationInt8WeightConfig(), Int8WeightOnlyConfig() or Int4WeightOnlyConfig()

can be used to apply the desired quantization technique. Once the quantized model is compiled with torch.compile in max-autotune mode, quantization is complete and we can measure the speedup.
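One rough way to judge whether a layer is compute-bound or memory-bound is its arithmetic intensity: floating point operations per byte of data moved. The sketch below is a back-of-the-envelope estimate in plain Python, assuming each operand is read once and the output written once (no cache effects). The function name matmul_intensity is hypothetical, and the single-row case is included only for contrast.

```python
def matmul_intensity(m, k, n, bytes_per_element=2):
    """FLOPs per byte moved for an (m x k) @ (k x n) matmul, ignoring caching."""
    flops = 2 * m * k * n  # one multiply and one add per accumulated term
    # Read both inputs once and write the output once.
    bytes_moved = bytes_per_element * (m * k + k * n + m * n)
    return flops / bytes_moved


# Shape from the autotune log: 16 images x 64 x 64 tokens = 65536 rows (bf16).
large_batch = matmul_intensity(65536, 1280, 5120)
# The same weight matrix applied to a single row (hypothetical, for contrast).
single_row = matmul_intensity(1, 1280, 5120)

print(f"65536-row matmul: {large_batch:.0f} FLOPs per byte")
print(f"single-row matmul: {single_row:.2f} FLOPs per byte")
```

The large-batch matmul performs on the order of a thousand operations per byte, so the arithmetic units are the bottleneck; the single-row case performs about one, so loading the weights dominates. The first situation favors dynamic quantization, the second weight-only quantization.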

Note

You might experience issues with these on older versions of PyTorch. If you run into an issue, you can use apply_dynamic_quant and apply_weight_only_int8_quant instead as drop-in replacements for the first two (there is no replacement for int4).

The difference between the two APIs is that the Int8DynamicActivationInt8WeightConfig API alters the weight tensor of the linear module so instead of doing a normal linear, it does a quantized operation. This is helpful when you have non-standard linear ops that do more than one thing. The apply APIs directly swap the linear modules for a quantized module which works on older versions but doesn’t work with non-standard linear modules.
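To make "a quantized operation" concrete, here is a minimal sketch of int8 dynamic quantization for a single linear layer, written in plain Python so it runs without a GPU. It assumes a simple symmetric scheme with one scale per activation row and one per output channel; this is an illustration of the idea, not torchao's actual implementation, and the names quantize_symmetric and int8_dynamic_linear are hypothetical.

```python
def quantize_symmetric(values):
    """Map floats to int8 with a single scale (symmetric, zero point = 0)."""
    max_abs = max(abs(v) for v in values) or 1.0  # avoid dividing by zero
    scale = max_abs / 127.0
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale


def int8_dynamic_linear(x_row, weight_rows):
    """Compute x @ W.T with integer accumulation, then rescale to float."""
    # "Dynamic": the activation scale is computed at run time from the input.
    xq, x_scale = quantize_symmetric(x_row)
    outputs = []
    for w in weight_rows:  # one row of W per output channel
        # In a real system the weights are quantized once, ahead of time.
        wq, w_scale = quantize_symmetric(w)
        acc = sum(a * b for a, b in zip(xq, wq))  # pure integer arithmetic
        outputs.append(acc * x_scale * w_scale)   # undo both scales
    return outputs


x = [0.5, -1.0, 2.0]
weight = [[1.0, 0.0, -1.0], [0.25, 0.5, 0.75]]

approx = int8_dynamic_linear(x, weight)
exact = [sum(a * b for a, b in zip(x, w)) for w in weight]
print("int8 result: ", approx)
print("float result:", exact)
```

The integer result tracks the floating point result closely while the inner product runs entirely on integers, which is where the speedup on compute-bound layers comes from.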

In this case Segment Anything is compute-bound so we’ll use dynamic quantization:

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the quantized block is 19.04ms and peak memory  3.58GB
AUTOTUNE int_mm(65536x5120, 5120x1280)
  triton_mm_257 6.2730 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 6.4666 ms 97.0%
  triton_mm_258 6.6222 ms 94.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_259 6.6898 ms 93.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_251 6.7267 ms 93.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_250 6.7410 ms 93.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_252 6.9130 ms 90.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_256 7.0277 ms 89.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
  triton_mm_253 7.1926 ms 87.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_249 7.1977 ms 87.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9777 seconds and 0.1559 seconds precompiling for 12 choices
AUTOTUNE int_mm(78400x1280, 1280x3840)
  triton_mm_154 5.7958 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_155 6.3027 ms 92.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_156 6.3201 ms 91.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 6.4051 ms 90.5%
  triton_mm_147 6.4491 ms 89.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_146 6.7779 ms 85.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_148 6.9765 ms 83.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_149 7.2397 ms 80.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_150 7.6257 ms 76.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_153 8.9364 ms 64.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9730 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE bmm(6400x196x80, 6400x80x196)
  triton_bmm_167 1.8156 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_171 1.8248 ms 99.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_163 1.9364 ms 93.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_173 1.9405 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_172 1.9436 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
  triton_bmm_168 1.9507 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_166 1.9558 ms 92.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_161 2.0060 ms 90.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_164 2.0296 ms 89.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
  triton_bmm_169 2.0326 ms 89.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.9460 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE bmm(14x89600x80, 14x80x16)
  triton_bmm_191 0.4680 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_177 0.4731 ms 98.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
  triton_bmm_180 0.4813 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_186 0.4813 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  bmm 0.4823 ms 97.0%
  triton_bmm_190 0.4905 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_183 0.4915 ms 95.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_179 0.5038 ms 92.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_178 0.5151 ms 90.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
  triton_bmm_185 0.5253 ms 89.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.6911 seconds and 0.0003 seconds precompiling for 17 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
  triton_bmm_196 0.5437 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_193 0.5448 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
  triton_bmm_202 0.5448 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_207 0.5458 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_199 0.5571 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_206 0.5571 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_195 0.5755 ms 94.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_194 0.6093 ms 89.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
  triton_bmm_201 0.6400 ms 85.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_204 0.6840 ms 79.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7040 seconds and 0.0003 seconds precompiling for 17 choices
AUTOTUNE bmm(6400x196x196, 6400x196x80)
  triton_bmm_219 1.8493 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_209 1.8893 ms 97.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_210 1.9743 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_218 1.9835 ms 93.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_215 2.0122 ms 91.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
  triton_bmm_217 2.0193 ms 91.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_214 2.0480 ms 90.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_223 2.0879 ms 88.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=8
  triton_bmm_222 2.1197 ms 87.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
  triton_bmm_224 2.1688 ms 85.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8789 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE int_mm(78400x1280, 1280x1280)
  triton_mm_235 1.9405 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_236 2.1545 ms 90.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_237 2.1699 ms 89.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 2.1965 ms 88.3%
  triton_mm_228 2.2313 ms 87.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_227 2.2630 ms 85.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_230 2.3091 ms 84.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_229 2.4576 ms 79.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_231 2.5364 ms 76.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_234 2.9860 ms 65.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.6225 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE int_mm(65536x1280, 1280x5120)
  triton_mm_246 6.4430 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_247 7.0072 ms 91.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_248 7.0216 ms 91.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 7.0748 ms 91.1%
  triton_mm_239 7.2141 ms 89.3% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_238 7.5468 ms 85.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_240 7.7793 ms 82.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_241 8.0364 ms 80.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_242 8.3753 ms 76.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_245 9.9881 ms 64.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 1.0466 seconds and 0.0002 seconds precompiling for 12 choices
bf16 compiled runtime of the quantized block is 45.98ms and peak memory  3.55GB

With quantization, we have improved performance a bit more but memory usage increased significantly.

This is for two reasons:

  1. Quantization adds overhead to the model since we need to quantize and dequantize the input and output. For small batch sizes this overhead can actually make the model go slower.

  2. Even though we are doing a quantized matmul, such as int8 x int8, the result of the multiplication gets stored in an int32 tensor, which is twice the size of the bf16 result from the non-quantized model. If we can avoid creating this int32 tensor, our memory usage will improve a lot.
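
To put a rough number on the second point, the sketch below estimates the size of a single matmul output in int32 and in bf16. The shape is taken from the `int_mm(65536x1280, 1280x5120)` entry in the autotune log above; the helper name `gib` is only illustrative, and this covers one layer's output, not the model's total peak memory.

```python
# Back-of-the-envelope size of one matmul output, int32 vs. bf16.
# Shape taken from the autotune log above: int_mm(65536x1280, 1280x5120).
rows, cols = 65536, 5120

BYTES_PER_INT32 = 4  # accumulator dtype of the int8 x int8 matmul
BYTES_PER_BF16 = 2   # dtype of the final, rescaled output


def gib(num_bytes):
    """Convert a byte count to GiB (hypothetical helper for this sketch)."""
    return num_bytes / 2**30


int32_out_gib = gib(rows * cols * BYTES_PER_INT32)
bf16_out_gib = gib(rows * cols * BYTES_PER_BF16)

print(f"int32 intermediate: {int32_out_gib:.3f} GiB")
print(f"bf16 output:        {bf16_out_gib:.3f} GiB")
```

For this one shape, the int32 intermediate is twice as large as the bf16 output, which is why avoiding it matters for peak memory.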

We can fix #2 by fusing the integer matmul with the subsequent rescale operation. Since the final output will be bf16, if we immediately convert the int32 result to bf16 and store that instead, we’ll get better performance in terms of both runtime and memory.
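
As a minimal sketch of what fusing the matmul with the rescale means, the toy example below uses plain Python lists in place of tensors. It assumes per-row activation scales and per-column weight scales; all values and function names are made up for illustration and do not come from torchao or Inductor.

```python
# Toy int8 matmul followed by a rescale. Plain lists stand in for tensors;
# every number and name here is made up for illustration.
x_int8 = [[12, -7, 3], [0, 25, -128]]   # quantized activations (2 x 3)
w_int8 = [[5, -1], [9, 4], [-3, 127]]   # quantized weights     (3 x 2)
x_scale = [0.02, 0.05]                  # assumed: one scale per activation row
w_scale = [0.01, 0.03]                  # assumed: one scale per weight column


def dot(i, j):
    """Integer dot product of row i of x and column j of w (the wide accumulator)."""
    return sum(x_int8[i][k] * w_int8[k][j] for k in range(len(w_int8)))


def matmul_then_rescale():
    # Unfused: materialize the full integer result first, then rescale it.
    acc = [[dot(i, j) for j in range(2)] for i in range(2)]
    return [[acc[i][j] * x_scale[i] * w_scale[j] for j in range(2)] for i in range(2)]


def fused_matmul_rescale():
    # Fused: rescale each accumulator as soon as it is computed, so the full
    # integer matrix is never stored.
    return [[dot(i, j) * x_scale[i] * w_scale[j] for j in range(2)] for i in range(2)]


unfused = matmul_then_rescale()
fused = fused_matmul_rescale()
print(unfused == fused)
```

Both versions produce the same values. The difference is only whether the wide integer matrix exists as a whole in memory at any point.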

The way to do this is to enable the option force_fuse_int_mm_with_mul in the inductor config.

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the fused quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the fused quantized block is 18.78ms and peak memory  2.37GB
bf16 compiled runtime of the fused quantized block is 46.06ms and peak memory  3.54GB

The fusion improves performance by another small amount (about 6% over the baseline in total) and removes almost all of the memory increase. The remaining difference (2.37GB quantized vs 2.24GB unquantized) is due to quantization overhead, which cannot be avoided.

We’re still not done, though. We can apply a few general-purpose optimizations to get our final best-case performance.

  1. We can sometimes improve performance by disabling epilogue fusion since the autotuning process can be confused by fusions and choose bad kernel parameters.

  2. We can apply coordinate descent tuning in all directions to enlarge the search area for kernel parameters.
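
To illustrate why searching in all directions can help, here is a toy coordinate descent over two kernel parameters. This is a sketch of the general idea only, not Inductor's implementation: the cost function is made up to stand in for a real benchmark, and `CHOICES`, `cost`, `neighbors`, and `tune` are hypothetical names.

```python
import itertools
import math

CHOICES = [16, 32, 64, 128, 256]  # candidate values for each tunable field


def cost(cfg):
    """Made-up stand-in for a kernel benchmark; lower is better."""
    a, b = (math.log2(v) for v in cfg)
    return 10 * (a - b) ** 2 - (a + b)


def neighbors(cfg, all_directions):
    """Yield configs one step away from cfg.

    With all_directions=False only one field changes at a time;
    with all_directions=True several fields may change together.
    """
    idx = [CHOICES.index(v) for v in cfg]
    for delta in itertools.product((-1, 0, 1), repeat=len(cfg)):
        moved = sum(d != 0 for d in delta)
        if moved == 0 or (moved > 1 and not all_directions):
            continue
        new_idx = [i + d for i, d in zip(idx, delta)]
        if all(0 <= i < len(CHOICES) for i in new_idx):
            yield tuple(CHOICES[i] for i in new_idx)


def tune(start, all_directions):
    """Greedy descent: move to the best neighbor while it improves the cost."""
    best = start
    while True:
        candidate = min(neighbors(best, all_directions), key=cost)
        if cost(candidate) >= cost(best):
            return best
        best = candidate


axis_only = tune((32, 32), all_directions=False)
all_dirs = tune((32, 32), all_directions=True)
print(axis_only, all_dirs)
```

With this particular cost function, changing one parameter at a time gets stuck at the starting point, while allowing both parameters to change together finds a better configuration. The larger search also costs more benchmark runs, which is the trade-off to keep in mind.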

del model_c, model, image
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.epilogue_fusion = False
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
    # needed for subclass + compile to work on older versions of pytorch
    unwrap_tensor_subclass(model)
model_c = torch.compile(model, mode='max-autotune')
quant_res = benchmark(model_c, image)
print(f"bf16 compiled runtime of the final quantized block is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
# bf16 compiled runtime of the final quantized block is 18.16ms and peak memory  2.39GB
bf16 compiled runtime of the final quantized block is 46.17ms and peak memory  3.54GB

As you can see, we’ve squeezed another small improvement from the model, taking our total improvement to over 10x compared to our original. To get a final estimate of the impact of quantization, let’s do an apples-to-apples comparison on the full model, since the actual improvement will differ block by block depending on the shapes involved.

try:
    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the compiled full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the compiled full model is 729.65ms and peak memory  23.96GB

    del model_c, model, image
    model, image = get_sam_model(False, batchsize)
    model = model.to(torch.bfloat16)
    image = image.to(torch.bfloat16)
    quantize_(model, Int8DynamicActivationInt8WeightConfig())
    if not TORCH_VERSION_AT_LEAST_2_5:
        # needed for subclass + compile to work on older versions of pytorch
        unwrap_tensor_subclass(model)
    model_c = torch.compile(model, mode='max-autotune')
    quant_res = benchmark(model_c, image)
    print(f"bf16 compiled runtime of the quantized full model is {quant_res['time']:0.2f}ms and peak memory {quant_res['memory']: 0.2f}GB")
    # bf16 compiled runtime of the quantized full model is 677.28ms and peak memory  24.93GB
except Exception as e:
    print("unable to run full model: ", e)
AUTOTUNE mm(78400x1280, 1280x3840)
  triton_mm_277 11.3859 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_283 11.3920 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_280 11.4616 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_281 11.5896 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_276 11.6101 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  mm 11.6849 ms 97.4%
  triton_mm_278 11.7044 ms 97.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_273 11.8395 ms 96.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_274 11.9173 ms 95.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_272 12.1457 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.0492 seconds and 0.0003 seconds precompiling for 20 choices
AUTOTUNE mm(78400x1280, 1280x1280)
  triton_mm_366 3.8154 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_365 3.8195 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_372 3.8205 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_369 3.8400 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_370 3.8411 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_362 3.8902 ms 98.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_367 3.8994 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_363 3.9086 ms 97.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  mm 3.9578 ms 96.4%
  triton_mm_361 3.9956 ms 95.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.1050 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x1280, 1280x1280)
  triton_mm_1360 3.1764 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1357 3.1949 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1358 3.1969 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_1354 3.2061 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_1353 3.2072 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1355 3.2399 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1350 3.2471 ms 97.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  mm 3.2809 ms 96.8%
  triton_mm_1351 3.3290 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_1349 3.3352 ms 95.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.0200 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE convolution(16x3x1024x1024, 1280x3x16x16)
  convolution 13.6366 ms 100.0%
  triton_convolution2d_263 15.1460 ms 90.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_266 17.1796 ms 79.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_261 17.8330 ms 76.5% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_260 18.6665 ms 73.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_265 24.7460 ms 55.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_264 26.1878 ms 52.1% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_262 203.3981 ms 6.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=16, KERNEL_W=16, PADDING_H=0, PADDING_W=0, STRIDE_H=16, STRIDE_W=16, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 2.5561 seconds and 0.0002 seconds precompiling for 8 choices
AUTOTUNE bmm(14x89600x80, 14x80x14)
  triton_bmm_322 0.5448 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=2
  triton_bmm_325 0.5458 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_331 0.5458 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_336 0.5468 ms 99.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_328 0.5540 ms 98.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_335 0.5601 ms 97.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_324 0.5786 ms 94.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_323 0.6205 ms 87.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=2
  triton_bmm_330 0.6400 ms 85.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_333 0.6840 ms 79.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=16, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.7069 seconds and 0.0002 seconds precompiling for 17 choices
AUTOTUNE mm(65536x1280, 1280x5120)
  triton_mm_391 12.6915 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_388 12.7713 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_384 12.8174 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_385 12.8236 ms 99.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_389 12.8922 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  mm 13.0171 ms 97.5%
  triton_mm_386 13.0509 ms 97.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_382 13.2434 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_381 13.3878 ms 94.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_380 13.5905 ms 93.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1151 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x5120, 5120x1280)
  triton_mm_410 12.5491 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  mm 12.5542 ms 100.0%
  triton_mm_407 12.6351 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_405 12.7570 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_403 12.7836 ms 98.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_404 13.0570 ms 96.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_408 13.1860 ms 95.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_400 13.2198 ms 94.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_401 13.7277 ms 91.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_411 13.8025 ms 90.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 2.1546 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE mm(65536x1280, 1280x3840)
  triton_mm_1305 9.5212 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1302 9.5754 ms 99.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1298 9.6082 ms 99.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1303 9.6758 ms 98.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_1299 9.7157 ms 98.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  mm 9.7628 ms 97.5%
  triton_mm_1300 9.8458 ms 96.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_1296 9.9400 ms 95.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_1295 9.9809 ms 95.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_1294 10.2287 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 1.7708 seconds and 0.0002 seconds precompiling for 20 choices
AUTOTUNE bmm(64x16384x80, 64x80x64)
  triton_bmm_1309 0.5970 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_1312 0.6124 ms 97.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_1325 0.6175 ms 96.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1316 0.6216 ms 96.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_1320 0.6226 ms 95.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_1311 0.6298 ms 94.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1315 0.6298 ms 94.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
  triton_bmm_1319 0.6328 ms 94.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_1310 0.6369 ms 93.7% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1324 0.6410 ms 93.1% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.7874 seconds and 0.0002 seconds precompiling for 19 choices
AUTOTUNE bmm(64x16384x80, 64x80x64)
  triton_bmm_1330 0.6912 ms 100.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_1334 0.6922 ms 99.9% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=4
  triton_bmm_1338 0.6923 ms 99.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=4, num_warps=4
  triton_bmm_1327 0.6963 ms 99.3% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=32, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=2, num_warps=4
  triton_bmm_1343 0.7137 ms 96.8% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1329 0.7383 ms 93.6% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=32, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1333 0.7393 ms 93.5% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=8
  triton_bmm_1337 0.7434 ms 93.0% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
  triton_bmm_1328 0.7578 ms 91.2% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=32, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=5, num_warps=8
  triton_bmm_1342 0.7649 ms 90.4% ACC_TYPE='tl.float32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=128, BLOCK_N=64, EVEN_K=False, GROUP_M=8, num_stages=3, num_warps=4
SingleProcess AUTOTUNE benchmarking takes 0.8044 seconds and 0.0003 seconds precompiling for 19 choices
AUTOTUNE convolution(16x1280x64x64, 256x1280x1x1)
  triton_convolution2d_4803 0.6799 ms 100.0% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_4806 0.7690 ms 88.4% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  convolution 0.7803 ms 87.1%
  triton_convolution2d_4808 0.9482 ms 71.7% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_4804 1.3998 ms 48.6% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  triton_convolution2d_4809 1.4193 ms 47.9% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=8
  triton_convolution2d_4807 1.4479 ms 47.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=2, num_warps=4
  conv1x1_via_mm 2.2589 ms 30.1%
  triton_convolution2d_4805 4.7555 ms 14.3% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=1, KERNEL_W=1, PADDING_H=0, PADDING_W=0, STRIDE_H=1, STRIDE_W=1, UNROLL=True, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4166 seconds and 0.0002 seconds precompiling for 9 choices
AUTOTUNE convolution(16x256x64x64, 256x256x3x3)
  convolution 1.4469 ms 100.0%
  triton_convolution2d_4811 2.6112 ms 55.4% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4816 2.8273 ms 51.2% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=256, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_4813 3.5226 ms 41.1% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_4814 4.3182 ms 33.5% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4810 4.8333 ms 29.9% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=4
  triton_convolution2d_4815 6.2812 ms 23.0% ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=256, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=2, num_warps=8
  triton_convolution2d_4812 9.2129 ms 15.7% ALLOW_TF32=True, BLOCK_K=16, BLOCK_M=1024, BLOCK_N=16, GROUPS=1, KERNEL_H=3, KERNEL_W=3, PADDING_H=1, PADDING_W=1, STRIDE_H=1, STRIDE_W=1, UNROLL=False, num_stages=1, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.4956 seconds and 0.0002 seconds precompiling for 8 choices
bf16 compiled runtime of the compiled full model is 2168.24ms and peak memory  15.29GB
AUTOTUNE int_mm(65536x1280, 1280x3840)
  triton_mm_5630 4.8394 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_5631 5.2756 ms 91.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_5632 5.2900 ms 91.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 5.3371 ms 90.7%
  triton_mm_5623 5.4436 ms 88.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_5622 5.6986 ms 84.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_5624 5.8470 ms 82.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_5625 6.0037 ms 80.6% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_5626 6.3754 ms 75.9% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_5629 7.5069 ms 64.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.9030 seconds and 0.0002 seconds precompiling for 12 choices
AUTOTUNE int_mm(65536x1280, 1280x1280)
  triton_mm_5677 1.6271 ms 100.0% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=64, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_5678 1.8125 ms 89.8% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=128, BLOCK_N=256, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  triton_mm_5679 1.8176 ms 89.5% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=128, BLOCK_M=256, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=8
  _int_mm 1.8350 ms 88.7%
  triton_mm_5669 1.8627 ms 87.4% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=4
  triton_mm_5670 1.8883 ms 86.2% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_5672 1.9128 ms 85.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=64, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_5671 2.0152 ms 80.7% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=3, num_warps=4
  triton_mm_5673 2.1115 ms 77.1% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=64, EVEN_K=True, GROUP_M=8, num_stages=4, num_warps=8
  triton_mm_5676 2.4791 ms 65.6% ACC_TYPE='tl.int32', ALLOW_TF32=True, BLOCK_K=32, BLOCK_M=128, BLOCK_N=128, EVEN_K=True, GROUP_M=8, num_stages=2, num_warps=8
SingleProcess AUTOTUNE benchmarking takes 0.5967 seconds and 0.0002 seconds precompiling for 12 choices
unable to run full model:  CUDA out of memory. Tried to allocate 8.00 GiB. GPU 0 has a total capacity of 22.07 GiB of which 3.19 GiB is free. Including non-PyTorch memory, this process has 18.85 GiB memory in use. Of the allocated memory 15.55 GiB is allocated by PyTorch, with 6.61 GiB allocated in private pools (e.g., CUDA Graphs), and 2.94 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

Conclusion

In this tutorial, we learned about quantization and optimization techniques, using the segment anything model as an example.

In the end, we achieved a full-model apples-to-apples quantization speedup of about 7.7% on batch size 16 (729.65ms to 677.28ms). We can push this a bit further by increasing the batch size and optimizing other parts of the model. For example, this can be done with some form of flash attention.
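
For reference, the speedup figure can be reproduced from the two full-model runtimes quoted above. The short sketch below also shows the same change expressed as a runtime reduction, which gives a slightly different percentage; the variable names are illustrative.

```python
baseline_ms = 729.65   # compiled bf16 full model (runtime quoted above)
quantized_ms = 677.28  # compiled and quantized full model (runtime quoted above)

speedup = baseline_ms / quantized_ms - 1     # relative throughput gain
time_saved = 1 - quantized_ms / baseline_ms  # relative runtime reduction

print(f"speedup: {speedup:.1%}, runtime reduction: {time_saved:.1%}")
```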

For more information visit torchao and try it on your own models.

Total running time of the script: (12 minutes 54.234 seconds)
