Quantized Inference#
Created On: Mar 24, 2026 | Last Updated On: Mar 24, 2026
For inference, we support dynamic and weight-only quantization of torch.nn.functional.linear across various dtype configurations.
The pseudocode is as follows:
# high precision (baseline)
output_bf16 = input_bf16 @ weight_bf16.t()
# dynamic quantization (shown for fp8 rowwise)
output_bf16 = to_fp8(input_bf16) @ weight_fp8.t()
# weight-only quantization (shown for int4)
output_bf16 = input_bf16 @ weight_int4.t()
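To make the arithmetic behind both flavors concrete, here is a minimal pure-Python sketch of symmetric per-tensor int8 quantization applied to a dot product. This is illustrative only: the function names are made up for this example, and real torchao kernels operate on tensors, pack values, and fuse the quantize step into surrounding ops.

```python
def quantize_int8(x):
    # symmetric per-tensor quantization: scale maps max |value| to 127
    scale = max(abs(v) for v in x) / 127 or 1.0
    q = [max(-128, min(127, round(v / scale))) for v in x]
    return q, scale

def qdot(a, b):
    # dynamic quantization: quantize both operands at runtime,
    # accumulate in integer, rescale the result to high precision
    qa, sa = quantize_int8(a)
    qb, sb = quantize_int8(b)
    acc = sum(x * y for x, y in zip(qa, qb))  # would be int32 on hardware
    return acc * sa * sb

def wo_dot(a, qw, sw):
    # weight-only quantization: dequantize the pre-quantized weight,
    # then compute the dot product in high precision
    w_hat = [v * sw for v in qw]
    return sum(x * y for x, y in zip(a, w_hat))

a = [0.5, -1.25, 2.0]
b = [1.0, 0.25, -0.5]
exact = sum(x * y for x, y in zip(a, b))  # high-precision baseline
approx = qdot(a, b)                       # dynamic quantization
wo = wo_dot(a, *quantize_int8(b))         # weight-only quantization
```

Both quantized results stay close to the baseline because the integer accumulation preserves most of the dot product's magnitude; the error is bounded by the rounding step, which shrinks as the quantized range is better utilized.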
Inference Workflows#
Below are the stable and near-stable inference workflows in torchao:
| weight dtype | act dtype | summary |
|---|---|---|
| float8 | float8 | |
| float8 | bf16 | |
| int8 | int8 | |
| int8 | bf16 | |
| mxfp8 | mxfp8 | |
| int4 | bf16 | |
| int4 | float8 | |
| nvfp4 | bf16 | |
| nvfp4 | nvfp4 | |
| mxfp4 | mxfp4 | |
| intx | bf16 | |
| intx | int8 | |
| uintx (4/8-bit) | bf16 | |
| uintx (4/8-bit) | int8 | |
| int4 | int8 | |
Accuracy benchmarks#
All the following benchmarks are for meta-llama/Llama-3.1-8B using lm-eval.
| weight | activation | wikitext-perplexity | winogrande | checkpoint size (GB) |
|---|---|---|---|---|
| bfloat16 | bfloat16 | 7.3315 | 0.7380 | 16.1 |
| float8_rowwise | float8_rowwise | 7.4197 | 0.7388 | 9.1 |
| int8_rowwise | bfloat16 | 7.3451 | 0.7340 | 9.1 |
| int8_rowwise | int8_rowwise | 7.4535 | 0.7285 | 9.1 |
| mxfp8 | mxfp8 | 7.6034 | 0.7316 | 9.32 |
| nvfp4 | nvfp4 | 8.4459 | 0.7135 | 6.05 |
To reproduce, run the following commands:
# on an H100
SKIP_VLLM=1 ./benchmarks/quantization/measure_accuracy_and_performance.sh h100
# on a B200
SKIP_VLLM=1 ./benchmarks/quantization/measure_accuracy_and_performance.sh b200
Performance benchmarks#
e2e model level benchmarks#
All the following benchmarks are for meta-llama/Llama-3.1-8B using torch==2.9.0 and vllm==0.13.0.
NVIDIA B200#
| weight | activation | prefill toks/s | decode toks/s | prefill_speedup | decode_speedup |
|---|---|---|---|---|---|
| bfloat16 | bfloat16 | 59099.9 | 14380 | 1 | 1 |
| mxfp8 | mxfp8 | TODO(https://github.com/pytorch/ao/issues/3549) | - | - | - |
| nvfp4 | nvfp4 | 102786 | 15218.9 | 1.739 | 1.058 |
| float8_rowwise | float8_rowwise | 69313.7 | 15984 | 1.173 | 1.112 |
NVIDIA H100#
| weight | activation | prefill toks/s | decode toks/s | prefill_speedup | decode_speedup |
|---|---|---|---|---|---|
| bfloat16 | bfloat16 | 30946.5 | 6612 | 1 | 1 |
| float8_rowwise | float8_rowwise | 45312.5 | 8025.95 | 1.464 | 1.214 |
| int8_rowwise | bfloat16 | 28231.9 | 4309.8 | 0.912 | 0.652 |
| int4 | float8_rowwise | TODO(https://github.com/pytorch/ao/issues/3550) | - | - | - |
To reproduce these benchmarks, run:
# on an H100
SKIP_LM_EVAL=1 ./benchmarks/quantization/measure_accuracy_and_performance.sh h100
# on a B200
SKIP_LM_EVAL=1 ./benchmarks/quantization/measure_accuracy_and_performance.sh b200
# under the hood, the actual vllm benchmark runs the following:
# 1. prefill
vllm bench throughput --num_prompts 32 --input_len 4096 --output_len 32 --max_model_len 4128
# 2. decode
vllm bench throughput --num_prompts 128 --input_len 32 --output_len 2048 --max_model_len 2080
Microbenchmarks and roofline model#
The following microbenchmarks show the roofline-expected and observed execution times of
a ReLU -> Linear toy model across a sweep of (M, K, N) shapes, with the activation
shaped (M, K) and the weight shaped (K, N). They can be used to estimate the expected speedup
of quantizing torch.nn.Linear layers with various recipes, based on the shapes in your model
during inference.
Explanation: to see speedup from quantization of activation -> gemm during inference, we want
(bf16_activation_time + bf16_gemm_time) > (bf16_activation_and_quantize_tensor_time + fp8_gemm_time)
In a perfect world (and our roofline model):
1. `bf16_activation_time > bf16_activation_and_quantize_tensor_time` is always true, because `bf16_activation` reads+writes `M*K*2` bytes and `bf16_activation_and_quantize_tensor` is a single fused kernel that reads+writes `M*K*1.5` bytes.
2. `bf16_gemm_time > fp8_gemm_time` is always true, as fp8 gemm has ~2x peak efficiency vs bf16 gemm.
In the real world, neither (1) nor (2) always holds, due to kernel launch overhead, kernel efficiency, lack of fusion for some recipes, etc. Therefore, the observed speedups are often significantly below the roofline peak. In general, you should expect the observed speedup from inference quantization to increase as M, K, and N increase.
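The byte and FLOP accounting above can be turned into a toy roofline calculator. The bandwidth and peak-FLOPS numbers below are illustrative assumptions (roughly Blackwell-class), not measured specs, and kernel-launch overhead is ignored, which is why this sketch is optimistic at small shapes:

```python
def roofline_speedup(M, K, N,
                     mem_bw=8.0e12,       # assumed memory bandwidth, bytes/s
                     bf16_flops=2.2e15,   # assumed peak bf16 FLOP/s
                     fp8_flops=4.4e15):   # assumed ~2x bf16 peak for fp8
    gemm_flops = 2 * M * K * N
    # bf16 activation kernel: M*K*2 bytes of traffic
    bf16_act_time = (M * K * 2) / mem_bw
    # fused activation + quantize kernel: M*K*1.5 bytes of traffic
    fp8_act_time = (M * K * 1.5) / mem_bw
    bf16_gemm_time = gemm_flops / bf16_flops
    fp8_gemm_time = gemm_flops / fp8_flops
    return (bf16_act_time + bf16_gemm_time) / (fp8_act_time + fp8_gemm_time)

# speedup approaches the ~2x gemm ratio as M, K, N grow
for n in (1024, 4096, 16384):
    print(n, round(roofline_speedup(n, n, n), 2))
```

As the gemm comes to dominate, the predicted speedup converges toward the 2x fp8-vs-bf16 compute ratio, consistent with the large-shape roofline numbers in the tables below.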
NVIDIA B200#
# `r_fp8_gemm_and_ovhd_spdp` is the roofline expected speedup of the
# quantized ReLU -> Linear layer vs high precision version
# `b_fp8_e2e_spdp` is the observed speedup of the quantized
# ReLU -> Linear layer vs high precision version
#
# mxfp8
#
> python benchmarks/float8/float8_inference_roofline.py --recipe_name mxfp8_cublas --enable_fusion_modeling True --skip_printing_detailed_metrics True
...
GPU NVIDIA B200
torch version 2.12.0.dev20260218+cu130
torchao version 0.17.0+git3075bb624
...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
0 1024 1024 1024 1.00 0.93
1 2048 2048 2048 1.75 1.20
2 4096 4096 4096 1.90 1.46
3 8192 8192 8192 1.94 1.76
4 16384 16384 16384 1.97 1.77
#
# nvfp4 with dynamic global scaling
#
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4 --enable_fusion_modeling True --skip_printing_detailed_metrics True
...
GPU NVIDIA B200
torch version 2.12.0.dev20260312+cu130
torchao version 0.17.0+gitbd7717d20
...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
0 1024 1024 1024 1.00 0.46
1 2048 2048 2048 2.36 0.76
2 4096 4096 4096 2.89 1.37
3 8192 8192 8192 3.32 1.97
4 16384 16384 16384 3.62 2.77
#
# nvfp4 with static global scaling (user API in progress)
#
> python benchmarks/float8/float8_inference_roofline.py --recipe_name nvfp4_static --enable_fusion_modeling True --skip_printing_detailed_metrics True
...
GPU NVIDIA B200
torch version 2.12.0.dev20260312+cu130
torchao version 0.17.0+gitbd7717d20
...
fwd_M fwd_K fwd_N r_fp8_gemm_and_ovhd_spdp b_fp8_e2e_spdp
0 1024 1024 1024 1.00 0.55
1 2048 2048 2048 2.74 0.95
2 4096 4096 4096 3.42 1.69
3 8192 8192 8192 3.67 2.29
4 16384 16384 16384 3.82 2.98
e2e flux-1.schnell benchmarks#
These benchmarks compare accuracy and performance of torchao inference quantization on the flux-1.schnell model.
For accuracy, we measure the LPIPS score between images generated by the quantized model and the high precision (bfloat16) baseline, averaged over the prompts from the sayakpaul/drawbench dataset — lower is better, with 0 meaning identical.
Note that this benchmark optimizes for speed of iteration and does not represent the best possible metrics someone could achieve on this model. Instead, this is an apples-to-apples comparison intended to compare different quantization recipes at a high level, and measure performance improvements.
| experiment | lpips_avg | time_s_bsz_1 | speedup_bsz_1 | time_s_bsz_4 | speedup_bsz_4 |
|---|---|---|---|---|---|
| bfloat16 | 0 | 0.4178 | 1.00 | 1.4914 | 1.00 |
| float8_rowwise | 0.1236 | 0.3455 | 1.21 | 1.1986 | 1.24 |
| mxfp8 | 0.1260 | 0.3673 | 1.14 | 1.2820 | 1.16 |
| nvfp4 | 0.2694 | 0.3203 | 1.30 | 1.0913 | 1.37 |
To reproduce, run:
./benchmarks/quantization/eval_accuracy_and_perf_of_flux.sh
Other Available Quantization Techniques#
Int8DynamicActivationIntxWeightConfig Quantization#
We have kernels that do 8-bit dynamic quantization of activations and uintx groupwise quantization of weights. These kernels are experimental and can only be run on devices with an ARM CPU (e.g., a Mac with Apple silicon). The benchmarks below were run on an M1 Mac Pro with 8 performance cores, 2 efficiency cores, and 32GB of RAM. In all cases, torch.compile was used.
| Model | Technique | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
|---|---|---|---|---|---|
| Llama-3.1-8B | Base (bfloat16) | 1.24 | 18.62 | NA | 15.01 |
| | int8_dynamic_activation_intx_weight-4-256-false | 16.03 | 65.81 | NA | 4.11 |
| | int8_dynamic_activation_intx_weight-3-256-false | 18.94 | 59.97 | NA | 3.17 |
You can try out these APIs with the quantize_ API as above, alongside the config Int8DynamicActivationIntxWeightConfig. An example can be found in torchao/_models/llama/generate.py.
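As a rough illustration of the groupwise weight quantization these kernels rely on, here is a pure-Python sketch of symmetric low-bit quantization with one scale per group of weights. The function names and values are made up for this example; the real kernels pack multiple low-bit values per byte and vectorize heavily.

```python
def quantize_groupwise(w, group_size=4, bits=4):
    # symmetric groupwise quantization: one scale per group of weights
    qmax = 2 ** (bits - 1) - 1  # 7 for 4-bit symmetric
    q, scales = [], []
    for i in range(0, len(w), group_size):
        group = w[i:i + group_size]
        scale = max(abs(v) for v in group) / qmax or 1.0
        scales.append(scale)
        q.extend(max(-qmax - 1, min(qmax, round(v / scale))) for v in group)
    return q, scales

def dequantize_groupwise(q, scales, group_size=4):
    return [q[i] * scales[i // group_size] for i in range(len(q))]

# a small group size keeps each scale close to its group's dynamic range,
# so small-magnitude groups are not crushed by a single large outlier
w = [0.1, -0.2, 0.05, 0.4, 3.0, -1.5, 0.75, 2.25]
q, scales = quantize_groupwise(w)
w_hat = dequantize_groupwise(q, scales)
```

The quantization error within each group is bounded by half that group's scale, which is the motivation for groupwise (rather than per-tensor) scaling when weights have uneven magnitude across channels.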
Codebook Quantization#
The benchmarks below were run on a single NVIDIA-A6000 GPU.
| Model | Technique | wikitext-perplexity | Tokens/Second | Memory Bandwidth (GB/s) | Peak Memory (GB) | Model Size (GB) |
|---|---|---|---|---|---|---|
| Llama-3-8B | Base (bfloat16) | 7.590 | 32.36 | 485.71 | 16.19 | 15.01 |
| | codebook-4-64 | 9.533 | 1.73 | 8.62 | 23.11 | 4.98 |
| Llama-3.1-8B | Base (bfloat16) | 7.713 | 32.16 | 482.70 | 16.35 | 15.01 |
| | codebook-4-64 | 10.095 | 1.73 | 8.63 | 23.11 | 4.98 |
You can try out these APIs with the quantize_ API as above, alongside the config CodebookWeightOnlyConfig. An example can be found in torchao/_models/llama/generate.py.
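To illustrate the idea behind codebook (lookup-table) quantization, here is a minimal pure-Python sketch: each weight is replaced by the index of its nearest entry in a small shared codebook, and dequantization is just a table lookup. The codebook values below are made up for illustration; in practice the codebook is fit to the weight distribution (for example with k-means) and the indices are bit-packed.

```python
def codebook_quantize(w, codebook):
    # store, for each weight, the index of the nearest codebook entry
    return [min(range(len(codebook)), key=lambda j: abs(v - codebook[j]))
            for v in w]

def codebook_dequantize(indices, codebook):
    # dequantization is a table lookup
    return [codebook[j] for j in indices]

codebook = [-1.0, -0.25, 0.0, 0.25, 1.0]  # toy 5-entry codebook (assumed)
w = [0.9, -0.3, 0.05, -1.1]
indices = codebook_quantize(w, codebook)
w_hat = codebook_dequantize(indices, codebook)
```

Because only the small codebook and the per-weight indices are stored, the memory footprint is dominated by the index bit width, independent of the precision of the codebook entries themselves.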
Low-Precision FP8 Attention (Prototype)#
FP8 low-precision attention for inference, built on Flash Attention backends. Currently supports FA3 on Hopper (SM90) and FA4 on Blackwell (SM100).
Requirements: PyTorch >= 2.11, Hopper or Blackwell GPU, Flash Attention 3 (pip install flash-attn-3 --index-url=https://download.pytorch.org/whl/{cuda_version}).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchao.prototype.attention import apply_low_precision_attention

# Simple model with attention
class MyModel(nn.Module):
    def __init__(self, embed_dim=512, num_heads=8):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=False)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.q_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        attn_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out_proj(attn_out.transpose(1, 2).contiguous().view(B, S, -1))

model = MyModel().to(device="cuda", dtype=torch.bfloat16).eval()

# Auto-detect best backend
model = apply_low_precision_attention(model)

# Or specify a backend explicitly
# model = apply_low_precision_attention(model, backend=AttentionBackend.FP8_FA3)

# Optional: torch.compile for RoPE fusion
model = torch.compile(model)
apply_low_precision_attention replaces all F.scaled_dot_product_attention calls with FP8 attention in eager execution. When combined with torch.compile, RoPE patterns are automatically detected and fused into a single kernel. For best results with torch.compile, disable KV caching before calling. See the API reference for details.