QATConfig
class torchao.quantization.qat.QATConfig(base_config: Optional[AOBaseConfig] = None, activation_config: Optional[FakeQuantizeConfigBase] = None, weight_config: Optional[FakeQuantizeConfigBase] = None, *, step: QATStep = 'prepare')
Config for applying quantization-aware training (QAT) to a torch.nn.Module, to be used with quantize_().
This config has two steps, “prepare” and “convert”. The prepare step applies “fake” quantization to the model and should be applied before training. The convert step turns the trained model into an actual quantized model. Fake quantization here means simulating low-precision quantization numerics (e.g. int4) with high-precision arithmetic (e.g. bf16). During training, the model learns to compensate for quantization error, which reduces the accuracy lost once it is actually quantized.
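To make "simulating the quantization numerics using high precision arithmetic" concrete, here is a minimal pure-Python sketch of symmetric, per-group fake quantization. The values are quantized to integer codes and then immediately dequantized back to floats. This is not torchao's implementation. The function name `fake_quantize_symmetric` is hypothetical. The scale convention (largest magnitude in the group maps to `qmax`) is one common choice and may differ from what torchao uses internally.

```python
def fake_quantize_symmetric(values, bits=4, group_size=32):
    """Hypothetical helper: quantize-then-dequantize `values` group by group.

    The result stays in floating point, but within each group every value is
    snapped to one of at most 2**bits levels, mimicking int<bits> storage.
    """
    qmin = -(2 ** (bits - 1))   # -8 for int4
    qmax = 2 ** (bits - 1) - 1  # 7 for int4
    out = []
    for start in range(0, len(values), group_size):
        group = values[start:start + group_size]
        max_abs = max(abs(v) for v in group)
        # Assumed convention: map the largest magnitude in the group to qmax.
        scale = max_abs / qmax if max_abs > 0 else 1.0
        for v in group:
            q = max(qmin, min(qmax, round(v / scale)))  # integer code
            out.append(q * scale)                        # back to float
    return out


weights = [0.12, -0.5, 0.33, 0.7]
fq = fake_quantize_symmetric(weights, bits=4, group_size=4)
# max |w| is 0.7, so the scale is about 0.1 and each weight snaps to
# (approximately) a multiple of 0.1.
print(fq)
```

Because the output is still floating point, gradients can flow through it during training. Only the set of representable values shrinks, which is what the prepare step exposes the model to.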
There are two ways to use this config. The first is to pass a base post-training quantization (PTQ) config. The fake quantization schemes for the prepare step are inferred from it automatically, and in the convert step the base PTQ config itself is applied to the model. This is expected to be the most common use case.
Example usage:
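    from torchao.quantization import (
        quantize_,
        Int8DynamicActivationInt4WeightConfig,
    )
    from torchao.quantization.qat import QATConfig

    # `model` and `train_loop` are placeholders for your own model and training code
    base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)

    # prepare: insert fake quantization inferred from base_config
    quantize_(model, QATConfig(base_config, step="prepare"))

    # fine-tune with fake quantization in the forward pass
    train_loop(model)

    # convert: apply the actual PTQ config to the trained model
    quantize_(model, QATConfig(base_config, step="convert"))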
Currently only the following are supported as base configs:
The second way to use this config is to specify the fake quantization schemes directly. Users pass FakeQuantizeConfigBase instances for weights and/or activations instead of a base PTQ config. This is mostly intended for experimentation, e.g. when the corresponding PTQ config does not exist yet.
Example usage:
    import torch
    from torchao.quantization import quantize_
    from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig

    # int8 asymmetric fake quantization for activations, per token
    activation_config = IntxFakeQuantizeConfig(
        torch.int8,
        "per_token",
        is_symmetric=False,
    )
    # int4 symmetric fake quantization for weights, in groups of 32
    weight_config = IntxFakeQuantizeConfig(
        torch.int4,
        group_size=32,
        is_symmetric=True,
    )
    qat_config = QATConfig(
        # must specify one of `base_config` or `weight_config`
        activation_config=activation_config,
        weight_config=weight_config,
        step="prepare",
    )
    quantize_(model, qat_config)  # `model` is a placeholder for your own nn.Module
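The activation config above requests asymmetric int8 fake quantization "per_token". Numerically, that roughly means each token gets its own scale and zero point. An outlier in one token therefore does not coarsen the quantization grid for the others. Below is a minimal pure-Python sketch under two stated assumptions. First, each row of a 2-D activation is one token. Second, the observed range is widened to include 0.0 so that zero is exactly representable. The function name is hypothetical, and torchao's exact range and rounding conventions may differ.

```python
def fake_quantize_per_token_asymmetric(rows, bits=8):
    """Hypothetical helper: asymmetric quantize-then-dequantize with one
    scale and zero point per row (token)."""
    qmin = -(2 ** (bits - 1))   # -128 for int8
    qmax = 2 ** (bits - 1) - 1  # 127 for int8
    out = []
    for row in rows:
        # Assumption: widen the range to include 0 so 0.0 maps to an exact code.
        lo = min(min(row), 0.0)
        hi = max(max(row), 0.0)
        scale = (hi - lo) / (qmax - qmin) if hi > lo else 1.0
        zero_point = qmin - round(lo / scale)  # integer code representing 0.0
        fq_row = []
        for x in row:
            q = max(qmin, min(qmax, round(x / scale) + zero_point))
            fq_row.append((q - zero_point) * scale)  # dequantize back to float
        out.append(fq_row)
    return out


acts = [[0.0, 1.0, 2.55], [-1.0, 0.5, 3.0]]  # two tokens, three features each
print(fake_quantize_per_token_asymmetric(acts))
```

Asymmetric quantization uses the full integer range even when a token's values are mostly positive or mostly negative, which is why it is often preferred for activations. Symmetric quantization is the simpler choice used for the weights in the example.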
- Parameters:
base_config (Optional[AOBaseConfig]) – Base PTQ config to infer the fake quantization configs during the prepare phase, and to apply directly during the convert phase.
activation_config (Optional[FakeQuantizeConfigBase]) – Custom fake quantization config for input activations, always optional. Must be None if base_config is used.
weight_config (Optional[FakeQuantizeConfigBase]) – Custom fake quantization config for weights. Must be None if base_config is used.
- Keyword Arguments:
step (str) – One of “prepare” or “convert”, determining which QAT phase to apply.
- Raises:
ValueError – If base_config and activation_config are both specified
ValueError – If base_config and weight_config are both specified
ValueError – If none of base_config, activation_config, or weight_config are specified
ValueError – If either activation_config or weight_config is specified and step is “convert”
ValueError – If step is not one of “prepare” or “convert”
ValueError – If the config is applied on a module that is not a torch.nn.Linear or torch.nn.Embedding, or it is applied on torch.nn.Embedding with an activation config
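The documented error conditions can be summarized as a small pure-Python check. This is a hypothetical sketch that mirrors the Raises list above, not torchao's source. The names `check_qat_args`, `check_target_module`, and `raises_value_error` are made up for illustration. Configs are treated as opaque objects, and the target module is described by a plain string rather than a real torch.nn.Module.

```python
VALID_STEPS = ("prepare", "convert")


def check_qat_args(base_config=None, activation_config=None,
                   weight_config=None, step="prepare"):
    """Hypothetical: raise ValueError for the argument combinations that the
    QATConfig documentation lists as invalid."""
    if step not in VALID_STEPS:
        raise ValueError(f"step must be one of {VALID_STEPS}, got {step!r}")
    if base_config is not None and activation_config is not None:
        raise ValueError("base_config and activation_config are mutually exclusive")
    if base_config is not None and weight_config is not None:
        raise ValueError("base_config and weight_config are mutually exclusive")
    if base_config is None and activation_config is None and weight_config is None:
        raise ValueError("specify base_config, activation_config, or weight_config")
    if step == "convert" and (activation_config is not None or weight_config is not None):
        raise ValueError("custom fake quantize configs are only valid in the prepare step")


def check_target_module(module_kind, has_activation_config):
    """Hypothetical: `module_kind` stands in for the module type, where
    "linear" and "embedding" mean torch.nn.Linear and torch.nn.Embedding."""
    if module_kind not in ("linear", "embedding"):
        raise ValueError(f"unsupported module type: {module_kind!r}")
    if module_kind == "embedding" and has_activation_config:
        raise ValueError("activation fake quantization is not supported for embeddings")


def raises_value_error(fn, *args, **kwargs):
    """Return True if fn(*args, **kwargs) raises ValueError, else False."""
    try:
        fn(*args, **kwargs)
    except ValueError:
        return True
    return False


base, act, wt = object(), object(), object()  # stand-ins for real config objects
print(raises_value_error(check_qat_args, base, step="prepare"))    # False
print(raises_value_error(check_qat_args, base, weight_config=wt))  # True
```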