
quantize_

torchao.quantization.quantize_(model: Module, apply_tensor_subclass: Callable[[Module], Module], filter_fn: Optional[Callable[[Module, str], bool]] = None, set_inductor_config: bool = True, device: Optional[Union[device, str, int]] = None)

Converts the weights of linear modules in the model using apply_tensor_subclass; the model is modified in place.

Parameters:
  • model (torch.nn.Module) – input model

  • apply_tensor_subclass (Callable[[torch.nn.Module], torch.nn.Module]) – function that applies tensor subclass conversion to the weight of a module and returns the module (e.g. converts the weight tensor of a linear module to an affine quantized tensor)

  • filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]) – function that takes an nn.Module instance and the fully qualified name of the module, and returns True if we want to run apply_tensor_subclass on the weight of that module (see the sketch after this list)

  • set_inductor_config (bool, optional) – Whether to automatically use recommended inductor config settings (defaults to True)

  • device (device, optional) – Device to move module to before applying filter_fn. This can be set to “cuda” to speed up quantization. The final model will be on the specified device. Defaults to None (do not change device).
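
To illustrate filter_fn (and how it composes with device), here is a minimal sketch; the Block module and the "attn" substring check are illustrative assumptions, not part of the API:

import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

class Block(nn.Module):  # hypothetical two-projection block, for illustration only
    def __init__(self):
        super().__init__()
        self.attn_proj = nn.Linear(64, 64)
        self.mlp_proj = nn.Linear(64, 64)

    def forward(self, x):
        return self.mlp_proj(self.attn_proj(x))

model = nn.Sequential(Block(), Block())

# only quantize linear modules whose fully qualified name mentions "attn",
# e.g. "0.attn_proj" but not "0.mlp_proj"
def attn_linear_only(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear) and "attn" in fqn

# optionally pass device="cuda" to move each module to GPU before quantizing
quantize_(model, int8_weight_only(), filter_fn=attn_linear_only)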

Example:

import torch
import torch.nn as nn
from torchao import quantize_

# 1. quantize with some predefined `apply_tensor_subclass` method that corresponds to
# optimized execution paths or kernels (e.g. the int4 tinygemm kernel), customizable
# with arguments. The current options are:
# int8_dynamic_activation_int4_weight (for executorch)
# int8_dynamic_activation_int8_weight (optimized with int8 mm op and torch.compile)
# int4_weight_only (optimized with int4 tinygemm kernel and torch.compile)
# int8_weight_only (optimized with int8 mm op and torch.compile)
from torchao.quantization.quant_api import int4_weight_only

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, int4_weight_only(group_size=32))
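
To confirm the swap took effect, you can check the weight type; a quick sketch (this assumes the int4 path is available on your hardware, since the tinygemm kernel targets bfloat16 CUDA tensors, and the exact subclass is a torchao implementation detail):

from torchao.dtypes import AffineQuantizedTensor

# the linear weights are now tensor subclasses rather than plain tensors
assert isinstance(m[0].weight, AffineQuantizedTensor)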

# 2. write your own `apply_tensor_subclass`
# You can also define your own `apply_tensor_subclass` by manually calling a tensor
# subclass constructor on the weight

from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization.quant_primitives import MappingType, ZeroPointDomain

# weight only uint4 asymmetric groupwise quantization
groupsize = 32
apply_weight_quant = lambda x: to_affine_quantized_intx(
  x, MappingType.ASYMMETRIC, (1, groupsize), torch.int32, 0, 15, 1e-6,
  zero_point_dtype=torch.bfloat16, preserve_zero=False, zero_point_domain=ZeroPointDomain.FLOAT)

def apply_weight_quant_to_linear(linear):
    linear.weight = torch.nn.Parameter(apply_weight_quant(linear.weight), requires_grad=False)
    return linear

# quantize the weight of every linear module; the fqn argument can be used for
# finer-grained filtering, e.g. restricting to a specific submodule
def filter_fn(module: nn.Module, fqn: str) -> bool:
    return isinstance(module, nn.Linear)

m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
quantize_(m, apply_weight_quant_to_linear, filter_fn)
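
As a sanity check for the custom flow, you can iterate over the modules and inspect each weight; a small sketch:

# both linear weights should now be affine quantized tensor subclasses
for fqn, mod in m.named_modules():
    if isinstance(mod, nn.Linear):
        print(fqn, type(mod.weight))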
