Shortcuts

Source code for torchao.quantization.qat.api

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Any, List, Optional, Tuple

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization.transform_module import (
    _QUANTIZE_CONFIG_HANDLER,
    register_quantize_module_handler,
)
from torchao.quantization.unified import TwoStepQuantizer

from .embedding import FakeQuantizedEmbedding
from .fake_quantize_config import (
    FakeQuantizeConfig,  # noqa: F401, for BC
    FakeQuantizeConfigBase,
    _infer_fake_quantize_configs,
)
from .linear import FakeQuantizedLinear
from .utils import _log_deprecation_warning


class QATStep(str, Enum):
    """
    Allowed values of the `step` field on
    :class:`~torchao.quantization.qat.QATConfig`.

    Members also subclass `str`, so each one compares equal to its plain
    string value (e.g. ``QATStep.PREPARE == "prepare"``).
    """

    PREPARE = "prepare"
    CONVERT = "convert"
@dataclass
class QATConfig(AOBaseConfig):
    """
    Config for applying quantization-aware training (QAT) to a `torch.nn.Module`,
    to be used with :func:`~torchao.quantization.quant_api.quantize_`.

    This config has two steps, "prepare" and "convert". The prepare step applies
    "fake" quantization to the model and should be applied before training, while
    the convert step converts the model into an actual quantized model. Fake
    quantization here refers to simulating the quantization numerics (e.g. int4)
    using high precision arithmetic (e.g. bf16), with the goal of reducing
    eventual degradation from quantization.

    There are two ways to use this config. The first involves passing a base
    post-training quantization (PTQ) config, which we will use to automatically
    infer the corresponding fake quantization schemes to use in the prepare phase.
    In the convert phase, we will then apply the base PTQ config to the model.
    This will be the most common use case.

    Example usage::

        from torchao.quantization import (
            quantize_,
            Int8DynamicActivationInt4WeightConfig,
        )
        from torchao.quantization.qat import QATConfig

        base_config = Int8DynamicActivationInt4WeightConfig(group_size=32)
        quantize_(model, QATConfig(base_config, step="prepare"))
        train_loop(model)
        quantize_(model, QATConfig(base_config, step="convert"))

    Currently only the following are supported as base configs:

        - :class:`~torchao.quantization.Int8DynamicActivationInt4WeightConfig`
        - :class:`~torchao.quantization.Int4WeightOnlyConfig`

    The second way to use this config involves specifying the fake quantization
    schemes directly. Users will pass in
    :class:`~torchao.quantization.qat.FakeQuantizeConfigBase` for weights and/or
    activations instead of the base PTQ config. This use case is mostly for
    experimentation, e.g. when the corresponding PTQ config does not exist yet.

    Example usage::

        from torchao.quantization import quantize_
        from torchao.quantization.qat import IntxFakeQuantizeConfig

        activation_config = IntxFakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False,
        )
        weight_config = IntxFakeQuantizeConfig(
            torch.int4, group_size=32, is_symmetric=True,
        )
        qat_config = QATConfig(
            # must specify one of `base_config` or `weight_config`
            activation_config=activation_config,
            weight_config=weight_config,
            step="prepare",
        )
        quantize_(model, qat_config)

    Args:
        base_config (Optional[AOBaseConfig]): Base PTQ config to infer the fake
            quantization configs during the prepare phase, and to apply directly
            during the convert phase.
        activation_config (Optional[FakeQuantizeConfigBase]): Custom fake
            quantization config for input activations, always optional.
            Must be None if `base_config` is used.
        weight_config (Optional[FakeQuantizeConfigBase]): Custom fake quantization
            config for weights. Must be None if `base_config` is used.

    Keyword args:
        step (str): One of "prepare" or "convert" (case-insensitive), determines
            the QAT phase. Stored on the instance as a lower-cased string.

    Raises:
        ValueError: If `base_config` and `activation_config` are both specified
        ValueError: If `base_config` and `weight_config` are both specified
        ValueError: If none of `base_config`, `activation_config`, or
            `weight_config` are specified and `step` is "prepare"
        ValueError: If either `activation_config` or `weight_config` is specified
            and `step` is "convert"
        ValueError: If `step` is not a string, or is not one of "prepare" or
            "convert"
        ValueError: If a `FakeQuantizeConfigBase` is passed as `base_config`
        ValueError: If the config is applied on a module that is not a
            `torch.nn.Linear` or `torch.nn.Embedding`, or it is applied on
            `torch.nn.Embedding` with an activation config
    """

    base_config: Optional[AOBaseConfig]
    activation_config: Optional[FakeQuantizeConfigBase]
    weight_config: Optional[FakeQuantizeConfigBase]
    step: QATStep

    # Express `step` as a keyword argument
    # TODO: Use `kw_only=True` instead, added in python 3.10
    def __init__(
        self,
        base_config: Optional[AOBaseConfig] = None,
        activation_config: Optional[FakeQuantizeConfigBase] = None,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
        *,
        step: QATStep = "prepare",
    ):
        self.base_config = base_config
        self.activation_config = activation_config
        self.weight_config = weight_config
        self.step = step
        # This hand-written `__init__` is what runs on construction, so the
        # validation hook is invoked explicitly here.
        self.__post_init__()

    def __post_init__(self):
        """Normalize `step` and validate the combination of configs."""
        torch._C._log_api_usage_once("torchao.quantization.qat.QATConfig")
        all_step_values = [s.value for s in QATStep]
        # Reject non-string steps (e.g. None) with the documented ValueError
        # instead of failing with AttributeError on `.lower()`.
        if not isinstance(self.step, str):
            raise ValueError(f"`step` must be one of {all_step_values}")
        self.step = self.step.lower()
        if self.step not in all_step_values:
            raise ValueError(f"`step` must be one of {all_step_values}")
        if self.base_config is not None and self.activation_config is not None:
            raise ValueError(
                "Cannot specify both `base_config` and `activation_config`"
            )
        if self.base_config is not None and self.weight_config is not None:
            raise ValueError("Cannot specify both `base_config` and `weight_config`")
        # Compare against None explicitly so "was a config supplied?" never
        # depends on the truthiness of a config object.
        nothing_specified = all(
            cfg is None
            for cfg in (self.base_config, self.activation_config, self.weight_config)
        )
        if self.step == QATStep.PREPARE and nothing_specified:
            raise ValueError(
                "Must specify `base_config`, `activation_config`, or `weight_config` in the prepare step"
            )
        if self.step == QATStep.CONVERT and (
            self.activation_config is not None or self.weight_config is not None
        ):
            raise ValueError(
                "Cannot specify `weight_config` or `activation_config` in the convert step"
            )
        # A fake quantize config passed positionally lands in `base_config`;
        # point the user at the keyword arguments they most likely intended.
        if isinstance(self.base_config, FakeQuantizeConfigBase):
            config_type = self.base_config.__class__.__name__
            raise ValueError(
                f"{config_type} was passed as `base_config`. Did you mean to do the following instead?\n"
                " qat_config = QATConfig(\n"
                f" activation_config={config_type}(...),\n"
                f" weight_config={config_type}(...),\n"
                ' step="prepare",\n'
                " )"
            )
@register_quantize_module_handler(QATConfig)
def _qat_config_transform(
    module: torch.nn.Module,
    config: QATConfig,
) -> torch.nn.Module:
    """
    Module-level handler for :class:`QATConfig`.

    Prepare step: swap `nn.Linear` -> `FakeQuantizedLinear` and
    `nn.Embedding` -> `FakeQuantizedEmbedding`. When a base PTQ config is
    given, the fake quantization configs are derived from it; otherwise the
    explicitly provided activation/weight configs are used.

    Convert step: swap fake quantized modules back to their built-in
    counterparts, then, if a base config is given, apply that config's
    registered handler to the restored module. Modules that are not fake
    quantized are returned untouched.

    Raises:
        ValueError: In the prepare step, if `module` is neither `nn.Linear`
            nor `nn.Embedding`, or if it is an `nn.Embedding` and an
            activation config is present.
    """
    ptq_config = config.base_config
    qat_step = config.step

    if qat_step == QATStep.PREPARE:
        # Resolve the fake quantization configs before inspecting the module.
        if base_is_given := ptq_config is not None:
            (fq_act, fq_weight) = _infer_fake_quantize_configs(ptq_config)
        if not base_is_given:
            fq_act = config.activation_config
            fq_weight = config.weight_config

        if isinstance(module, torch.nn.Linear):
            return FakeQuantizedLinear.from_linear(module, fq_act, fq_weight)
        if not isinstance(module, torch.nn.Embedding):
            raise ValueError(
                "Module of type '%s' does not have QAT support" % type(module)
            )
        if fq_act is not None:
            raise ValueError(
                "Activation fake quantization is not supported for embedding"
            )
        return FakeQuantizedEmbedding.from_embedding(module, fq_weight)

    # Convert step
    assert qat_step == QATStep.CONVERT, "unexpected step '%s' in QATConfig" % qat_step
    assert config.activation_config is None, "unexpected `activation_config`"
    assert config.weight_config is None, "unexpected `weight_config`"

    if isinstance(module, FakeQuantizedLinear):
        restored = module.to_linear()
    elif isinstance(module, FakeQuantizedEmbedding):
        restored = module.to_embedding()
    else:
        # Unrelated module, ignore
        return module

    # Without a base config the convert step is just the module swap.
    if ptq_config is None:
        return restored
    base_handler = _QUANTIZE_CONFIG_HANDLER[type(ptq_config)]
    return base_handler(restored, ptq_config)
@dataclass
class IntXQuantizationAwareTrainingConfig(AOBaseConfig):
    """
    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.

    Config that applies fake quantization to a `torch.nn.Module` when passed
    to :func:`~torchao.quantization.quant_api.quantize_`.

    Example usage::

        from torchao.quantization import quantize_
        from torchao.quantization.qat import IntxFakeQuantizeConfig

        activation_config = IntxFakeQuantizeConfig(
            torch.int8, "per_token", is_symmetric=False,
        )
        weight_config = IntxFakeQuantizeConfig(
            torch.int4, group_size=32, is_symmetric=True,
        )
        quantize_(
            model,
            IntXQuantizationAwareTrainingConfig(activation_config, weight_config),
        )

    Note:
        Only `torch.nn.Linear` and `torch.nn.Embedding` are supported.
        Applying this config to any other module type, or to a
        `torch.nn.Embedding` together with an activation config, raises
        `ValueError`.
    """

    activation_config: Optional[FakeQuantizeConfigBase] = None
    weight_config: Optional[FakeQuantizeConfigBase] = None

    def __post_init__(self):
        # Every construction of this deprecated config is reported through
        # the shared deprecation helper.
        _log_deprecation_warning(self)
# for BC
class intx_quantization_aware_training(IntXQuantizationAwareTrainingConfig):
    """Backward-compatible alias of `IntXQuantizationAwareTrainingConfig`."""


@register_quantize_module_handler(IntXQuantizationAwareTrainingConfig)
def _intx_quantization_aware_training_transform(
    module: torch.nn.Module,
    config: IntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    Swap a supported module for its fake quantized counterpart.

    `torch.nn.Linear` becomes a `FakeQuantizedLinear` built from both the
    activation and weight configs; `torch.nn.Embedding` becomes a
    `FakeQuantizedEmbedding` built from the weight config only.

    Raises:
        ValueError: If `module` is an embedding and an activation config is
            set, or if `module` is neither a linear nor an embedding.
    """
    act_cfg = config.activation_config
    wt_cfg = config.weight_config

    if isinstance(module, torch.nn.Linear):
        return FakeQuantizedLinear.from_linear(module, act_cfg, wt_cfg)

    if not isinstance(module, torch.nn.Embedding):
        raise ValueError(
            "Module of type '%s' does not have QAT support" % type(module)
        )

    if act_cfg is not None:
        raise ValueError(
            "Activation fake quantization is not supported for embedding"
        )
    return FakeQuantizedEmbedding.from_embedding(module, wt_cfg)
@dataclass
class FromIntXQuantizationAwareTrainingConfig(AOBaseConfig):
    """
    (Deprecated) Please use :class:`~torchao.quantization.qat.QATConfig` instead.

    Config for turning a model that contains fake quantized modules
    (`FakeQuantizedLinear`, `FakeQuantizedEmbedding`) back into a model with
    the original, corresponding modules without fake quantization. This
    should be used with :func:`~torchao.quantization.quant_api.quantize_`.

    Example usage::

        from torchao.quantization import quantize_
        quantize_(
            model_with_fake_quantized_linears,
            FromIntXQuantizationAwareTrainingConfig(),
        )
    """

    def __post_init__(self):
        # Every construction of this deprecated config is reported through
        # the shared deprecation helper.
        _log_deprecation_warning(self)
# for BC
class from_intx_quantization_aware_training(FromIntXQuantizationAwareTrainingConfig):
    """Backward-compatible alias of `FromIntXQuantizationAwareTrainingConfig`."""


@register_quantize_module_handler(FromIntXQuantizationAwareTrainingConfig)
def _from_intx_quantization_aware_training_transform(
    mod: torch.nn.Module,
    config: FromIntXQuantizationAwareTrainingConfig,
) -> torch.nn.Module:
    """
    Undo the fake quantization module swap for a single module.

    A `FakeQuantizedLinear` is converted via `to_linear()` and a
    `FakeQuantizedEmbedding` via `to_embedding()`; any other module is
    returned as-is. `config` carries no options and is not read.
    """
    if isinstance(mod, FakeQuantizedLinear):
        return mod.to_linear()
    if isinstance(mod, FakeQuantizedEmbedding):
        return mod.to_embedding()
    # Not a fake quantized module: nothing to revert.
    return mod
class ComposableQATQuantizer(TwoStepQuantizer):
    """
    Composable quantizer that users can use to apply multiple QAT quantizers easily.
    Quantizers will be applied in the order they are specified in the constructor.

    Note: the quantizers provided must apply to different modules in the model,
    e.g. nn.Linear and nn.Embedding, otherwise the behavior will be undefined.

    Any extra positional or keyword arguments given to :meth:`prepare` or
    :meth:`convert` are forwarded unchanged to every wrapped quantizer.

    Example usage::

        my_quantizer = ComposableQATQuantizer([
            QATQuantizer1(),
            QATQuantizer2(),
            QATQuantizer3(),
        ])
        model = my_quantizer.prepare(model)
        train(model)
        model = my_quantizer.convert(model)

    Args:
        quantizers (List[TwoStepQuantizer]): Quantizers to apply, in order.
    """

    def __init__(self, quantizers: List[TwoStepQuantizer]):
        torch._C._log_api_usage_once("torchao.quantization.qat.ComposableQATQuantizer")
        self.quantizers = quantizers

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `prepare` of each wrapped quantizer in order, feeding the output
        model of one into the next, and return the final model.
        """
        for quantizer in self.quantizers:
            # Forward the caller's extra arguments instead of silently
            # dropping them.
            model = quantizer.prepare(model, *args, **kwargs)
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Run `convert` of each wrapped quantizer in order, feeding the output
        model of one into the next, and return the final model.
        """
        for quantizer in self.quantizers:
            # Forward the caller's extra arguments instead of silently
            # dropping them.
            model = quantizer.convert(model, *args, **kwargs)
        return model
def initialize_fake_quantizers(
    model: torch.nn.Module,
    example_inputs: Tuple[Any, ...],
) -> None:
    """
    (Prototype) Initialize the scales and zero points on all
    :class:`~torchao.quantization.qat.fake_quantizer.IntxFakeQuantizer`
    in the model based on the provided example inputs.

    Concretely, every `IntxFakeQuantizer` submodule gets its `_initialized`
    flag set to True, and the model is then called once with
    `example_inputs` unpacked as positional arguments.

    Args:
        model (torch.nn.Module): Model containing the fake quantizers.
        example_inputs (Tuple[Any, ...]): Positional arguments for one
            forward call of `model`.
    """
    torch._C._log_api_usage_once("torchao.quantization.qat.initialize_fake_quantizers")

    # avoid circular dependencies
    from torchao.quantization.qat.fake_quantizer import IntxFakeQuantizer

    def _mark_initialized(submodule: torch.nn.Module):
        # Only fake quantizers carry the flag; leave other modules alone.
        if not isinstance(submodule, IntxFakeQuantizer):
            return
        submodule._initialized = True

    model.apply(_mark_initialized)
    model(*example_inputs)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources