QATConfig

class torchao.quantization.qat.QATConfig(base_config: Optional[AOBaseConfig] = None, activation_config: Optional[FakeQuantizeConfigBase] = None, weight_config: Optional[FakeQuantizeConfigBase] = None, *, step: QATStep = 'prepare')[source]

Config for applying quantization-aware training (QAT) to a torch.nn.Module, to be used with quantize_().

This config has two steps, “prepare” and “convert”. The prepare step applies “fake” quantization to the model and should be applied before training, while the convert step converts the model into an actual quantized model. Fake quantization here refers to simulating the quantization numerics (e.g. int4) using high precision arithmetic (e.g. bf16), with the goal of reducing eventual degradation from quantization.
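
To make the numerics concrete, below is a minimal sketch of the quantize-dequantize roundtrip that fake quantization simulates, using a hypothetical helper for symmetric per-tensor int4. This is an illustration of the idea only, not torchao's implementation:

import torch

def fake_quantize_int4_symmetric(x: torch.Tensor) -> torch.Tensor:
    # Symmetric int4 covers the integer range [-8, 7]
    qmin, qmax = -8, 7
    scale = x.abs().amax() / qmax
    # Round onto the int4 grid, then immediately scale back,
    # so the output stays in the original high-precision dtype
    q = torch.clamp(torch.round(x / scale), qmin, qmax)
    return q * scale

x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fq = fake_quantize_int4_symmetric(x)  # same dtype, quantization error baked in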

There are two ways to use this config. The first involves passing a base post-training quantization (PTQ) config, which we will use to automatically infer the corresponding fake quantization schemes to use in the prepare phase. In the convert phase, we will then apply the base PTQ config to the model. This is the most common use case.

Example usage:

from torchao.quantization import (
    quantize_,
    Int8DynamicActivationInt4WeightConfig,
)
from torchao.quantization.qat import QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
quantize_(model, QATConfig(base_config, step="prepare"))
train_loop(model)
quantize_(model, QATConfig(base_config, step="convert"))
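
After the prepare step, quantize_ swaps the matching modules in place for fake-quantized equivalents, so the rest of the training setup (optimizer, loss, data) can stay unchanged. A quick way to sanity-check the swap is to print the module types (illustrative only; the exact replacement class names are not specified here):

for name, module in model.named_modules():
    print(name, type(module).__name__)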

Currently only a limited set of PTQ configs are supported as base configs; refer to the torchao documentation for the supported list.

The second way to use this config is to specify the fake quantization schemes directly. Users pass a FakeQuantizeConfigBase for weights and/or activations instead of a base PTQ config. This use case is mostly for experimentation, e.g. when the corresponding PTQ config does not exist yet.

Example usage:

import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

activation_config = IntxFakeQuantizeConfig(
    torch.int8, "per_token", is_symmetric=False,
)
weight_config = IntxFakeQuantizeConfig(
    torch.int4, group_size=32, is_symmetric=True,
)
qat_config = QATConfig(
    # must specify one of `base_config` or `weight_config`
    activation_config=activation_config,
    weight_config=weight_config,
    step="prepare",
)
quantize_(model, qat_config)
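
Note that custom fake quantization configs are only supported in the “prepare” step; the “convert” step requires base_config, and specifying activation_config or weight_config with step="convert" raises a ValueError (see below).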
Parameters:
  • base_config (Optional[AOBaseConfig]) – Base PTQ config to infer the fake quantization configs during the prepare phase, and to apply directly during the convert phase.

  • activation_config (Optional[FakeQuantizeConfigBase]) – Custom fake quantization config for input activations, always optional. Must be None if base_config is used.

  • weight_config (Optional[FakeQuantizeConfigBase]) – Custom fake quantization config for weights. Must be None if base_config is used.

Keyword Arguments:

step (str) – One of “prepare” or “convert”; determines the QAT phase. Defaults to “prepare”.

Raises:
  • ValueError – If base_config and activation_config are both specified

  • ValueError – If base_config and weight_config are both specified

  • ValueError – If none of base_config, activation_config, or weight_config are specified

  • ValueError – If either activation_config or weight_config is specified and step is “convert”

  • ValueError – If step is not one of “prepare” or “convert”

  • ValueError – If the config is applied to a module that is neither torch.nn.Linear nor torch.nn.Embedding, or if it is applied to torch.nn.Embedding with an activation config specified
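
The mutual-exclusion rules above can be exercised directly. A minimal sketch based on the documented ValueError conditions (the exact error messages are not specified here):

import torch

from torchao.quantization import Int8DynamicActivationInt4WeightConfig
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
weight_config = IntxFakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)

try:
    QATConfig(base_config, weight_config=weight_config)
except ValueError:
    pass  # base_config and weight_config are mutually exclusive

try:
    QATConfig()
except ValueError:
    pass  # at least one of the configs must be specified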
