Source code for torchao.quantization.qat.linear

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.linear_quant_modules import (
    Int8DynActInt4WeightLinear,
    WeightOnlyInt4Linear,
    _check_linear_int4_k,
    _replace_linear_8da4w,
    _replace_linear_int4,
    groupwise_affine_quantize_tensor,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
    ZeroPointDomain,
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric

from .fake_quantize_config import (
    FakeQuantizeConfigBase,
    Float8FakeQuantizeConfig,
    IntxFakeQuantizeConfig,
)
from .fake_quantizer import (
    FakeQuantizerBase,
)
from .utils import (
    _get_qmin_qmax,
)


class FakeQuantizedLinear(torch.nn.Linear):
    """
    General linear layer with fake quantized weights and activations.

    Specific target dtypes, granularity, schemes etc. are specified
    through separate configs for weights and activations.

    Example usage::

        activation_config = IntxFakeQuantizeConfig(
            dtype=torch.int8,
            granularity="per_token",
            is_symmetric=False,
        )
        weight_config = IntxFakeQuantizeConfig(
            dtype=torch.int4,
            group_size=8,
            is_symmetric=True,
        )
        fq_linear = FakeQuantizedLinear(
            16, 32, False, activation_config, weight_config,
        )
        fq_linear(torch.randn(16))
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        activation_config: Optional[FakeQuantizeConfigBase] = None,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias,
            *args,
            **kwargs,
        )
        torch._C._log_api_usage_once("torchao.quantization.qat.FakeQuantizedLinear")
        # initialize activation fake quantizer
        if activation_config is not None:
            self.activation_fake_quantizer = FakeQuantizerBase.from_config(
                activation_config
            )
        else:
            self.activation_fake_quantizer = None

        # initialize weight fake quantizer
        if weight_config is not None:
            if isinstance(weight_config, IntxFakeQuantizeConfig) and isinstance(
                weight_config.granularity, PerGroup
            ):
                group_size = weight_config.group_size
                if group_size is not None and in_features % group_size != 0:
                    raise ValueError(
                        "in_features (%s) %% group_size (%s) must be == 0"
                        % (in_features, group_size)
                    )
            self.weight_fake_quantizer = FakeQuantizerBase.from_config(weight_config)
        else:
            self.weight_fake_quantizer = None
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation_fake_quantizer is not None:
            x = self.activation_fake_quantizer(x)
        if self.weight_fake_quantizer is not None:
            w = self.weight_fake_quantizer(self.weight)
        else:
            w = self.weight
        return F.linear(x, w, self.bias)
    def to_linear(self) -> torch.nn.Linear:
        new_linear = torch.nn.Linear(
            self.in_features,
            self.out_features,
            self.bias is not None,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if self.weight.device != torch.device("meta"):
            new_linear.weight = self.weight
            new_linear.bias = self.bias
        return new_linear

    @classmethod
    def from_linear(
        cls,
        mod: torch.nn.Linear,
        activation_config: Optional[FakeQuantizeConfigBase] = None,
        weight_config: Optional[FakeQuantizeConfigBase] = None,
    ):
        new_linear = FakeQuantizedLinear(
            mod.in_features,
            mod.out_features,
            mod.bias is not None,
            activation_config=activation_config,
            weight_config=weight_config,
            device=mod.weight.device,
            dtype=mod.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device != torch.device("meta"):
            new_linear.weight = mod.weight
            new_linear.bias = mod.bias
        return new_linear

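# Illustrative usage sketch (an assumption, not part of the module source):
# wrap an existing nn.Linear with `from_linear` to fake quantize its weights,
# run a forward pass, then recover a plain nn.Linear with `to_linear`.
# Variable names prefixed with `_example_` are purely illustrative.
_example_weight_config = IntxFakeQuantizeConfig(
    dtype=torch.int4,
    group_size=8,
    is_symmetric=True,
)
_example_fq_linear = FakeQuantizedLinear.from_linear(
    torch.nn.Linear(16, 32),
    weight_config=_example_weight_config,
)
_example_out = _example_fq_linear(torch.randn(2, 16))  # forward with fake quantized weights
_example_plain_linear = _example_fq_linear.to_linear()  # back to a regular nn.Linear
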
def enable_linear_fake_quant(
    mod: torch.nn.Module,
    enabled: bool = True,
):
    """
    Helper function to enable fake quantization in `FakeQuantizedLinear`.
    """
    if isinstance(mod, FakeQuantizedLinear):
        if mod.activation_fake_quantizer is not None:
            mod.activation_fake_quantizer.enabled = enabled
        if mod.weight_fake_quantizer is not None:
            mod.weight_fake_quantizer.enabled = enabled
def disable_linear_fake_quant(mod: torch.nn.Module):
    """
    Helper function to disable fake quantization in `FakeQuantizedLinear`.
    """
    enable_linear_fake_quant(mod, enabled=False)

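# Illustrative usage sketch (an assumption, not part of the module source):
# these helpers operate on a single module, so they are typically applied to a
# whole model with `torch.nn.Module.apply`, e.g. to temporarily run in full
# precision and then turn fake quantization back on.
_example_model = torch.nn.Sequential(
    FakeQuantizedLinear(
        16,
        32,
        weight_config=IntxFakeQuantizeConfig(
            dtype=torch.int4, group_size=8, is_symmetric=True
        ),
    ),
    torch.nn.ReLU(),
)
_example_model.apply(disable_linear_fake_quant)  # full precision forward passes
_example_model.apply(enable_linear_fake_quant)   # fake quantization back on
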
# ===========================
# | QAT quantizer interface |
# ===========================


class _LegacyQATQuantizer(TwoStepQuantizer):
    """
    Base class for sharing common methods across legacy QAT quantizers.
    """

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return None

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return None


# ===========================================
# | int8 dynamic activations + int4 weights |
# ===========================================
class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have int8
    dynamic per token fake quantized activations and int4 fake quantized
    grouped per channel weights.
    """

    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = False,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Int8DynActInt4WeightQATQuantizer"
        )
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.precision: torch.dtype = precision
        self.scales_precision: torch.dtype = scales_precision
        # TODO: generalize this
        self.activation_scales_precision = torch.float32

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        _replace_linear_8da4w(
            model,
            self.groupsize,
            self.padding_allowed,
            self.precision,
            self.scales_precision,
            Int8DynActInt4WeightQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        self._convert_qat_linear_8da4w(model)
        return model

    def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
        """
        Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
        """
        for name, child in module.named_children():
            if isinstance(child, Int8DynActInt4WeightQATLinear):
                config = child.weight_fake_quantizer.config
                quantized_linear = Int8DynActInt4WeightLinear(
                    child.in_features,
                    child.out_features,
                    child.bias is not None,
                    groupsize=config.group_size,
                    precision=child.weight.dtype,
                    scales_precision=config.scale_precision,
                )
                setattr(module, name, quantized_linear)

                # Load weights and qparams into quantized linear
                n_bit = 4
                (qmin, qmax) = _get_qmin_qmax(n_bit)
                (s, zp) = get_group_qparams_symmetric(
                    child.weight,
                    n_bit,
                    config.group_size,
                    precision=config.scale_precision,
                )
                zp = zp.to(config.zero_point_precision)
                from torchao._executorch_ops import (
                    _quantized_decomposed_quantize_per_channel_group_wrapper,
                )

                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
                    child.weight,
                    s,
                    zp,
                    qmin,
                    qmax,
                    torch.int8,
                    config.group_size,
                )
                quantized_linear.weight = q_weight
                quantized_linear.scales = s
                quantized_linear.zeros = zp
                if child.bias is not None:
                    quantized_linear.bias = child.bias
            else:
                self._convert_qat_linear_8da4w(child)

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return _get_8da4w_activation_config(self.activation_scales_precision)

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)

class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
    """
    This module implements a linear layer with int8 dynamic per token fake
    quantized activations with int4 fake quantized grouped per channel weights.

    args:
        groupsize: the number of elements in each quantized group for weights
        precision: precision of weights
        scales_precision: precision of per group scales and zero points

    Note: we hardcode activation scales to use torch.fp32, but allow users to
    specify the weight scales (defaults to torch.fp32). Here scales_precision
    refers specifically to the weight scales only, not the activation scales.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
        # which is used in PTQ routines
        # TODO: generalize this
        activation_config = _get_8da4w_activation_config(torch.float32)
        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
        super().__init__(
            in_features,
            out_features,
            bias,
            activation_config,
            weight_config,
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        self.activation_fake_quantizer.enabled = enabled
        self.weight_fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        self.enable_fake_quant(False)

# TODO: remove these in favor of enable_linear_fake_quant
def enable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
    """
    if isinstance(mod, Int8DynActInt4WeightQATLinear):
        mod.enable_fake_quant()


# TODO: remove in favor of disable_linear_fake_quant
def disable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
    """
    if isinstance(mod, Int8DynActInt4WeightQATLinear):
        mod.disable_fake_quant()


def _get_8da4w_activation_config(
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the activation `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
    """
    # TODO: generalize this
    assert qparams_precision == torch.float32
    return IntxFakeQuantizeConfig(
        dtype=torch.int8,
        granularity="per_token",
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        eps=torch.finfo(qparams_precision).eps,
    )


def _get_8da4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the weight `IntxFakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
    """
    return IntxFakeQuantizeConfig(
        dtype=TorchAODType.INT4,
        group_size=group_size,
        is_symmetric=True,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
    )

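# Illustrative end-to-end sketch (an assumption, not part of the module source):
# the quantizer follows the two-step prepare/convert flow. `prepare` swaps
# nn.Linear for Int8DynActInt4WeightQATLinear so training sees fake quantized
# numerics; `convert` then swaps those for real Int8DynActInt4WeightLinear.
# Names prefixed with `_example_` are purely illustrative.
_example_8da4w_model = torch.nn.Sequential(torch.nn.Linear(256, 256))
_example_8da4w_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=256)
_example_8da4w_model = _example_8da4w_quantizer.prepare(_example_8da4w_model)
# ... QAT fine-tuning loop would go here ...
_example_8da4w_model = _example_8da4w_quantizer.convert(_example_8da4w_model)
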
# ====================
# | int4 weight-only |
# ====================


class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have
    int4 fake quantized grouped per channel weights.
    """

    def __init__(
        self,
        groupsize: int = 256,
        inner_k_tiles: Optional[int] = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Int4WeightOnlyQATQuantizer"
        )
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]
        self.inner_k_tiles = inner_k_tiles
        self.groupsize = groupsize
        self.precision = precision
        self.scales_precision = scales_precision

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        _replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            padding_allowed=True,
            precision=self.precision,
            scales_precision=self.scales_precision,
            linear_class=Int4WeightOnlyQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        self._convert_qat_linear_4w(model)
        return model

    def _convert_qat_linear_4w(self, module: torch.nn.Module):
        """
        Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.
        """
        for name, child in module.named_children():
            if isinstance(child, Int4WeightOnlyQATLinear):
                in_features = child.in_features
                out_features = child.out_features
                inner_k_tiles = child.inner_k_tiles
                config = child.weight_fake_quantizer.config
                quantized_linear = WeightOnlyInt4Linear(
                    in_features,
                    out_features,
                    bias=False,
                    groupsize=config.group_size,
                    inner_k_tiles=inner_k_tiles,
                    precision=child.weight.dtype,
                    scales_precision=config.scale_precision,
                    device=next(child.parameters()).device,
                )
                setattr(module, name, quantized_linear)

                # Load weights and qparams into quantized linear
                n_bit = 4
                (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    child.weight,
                    n_bit,
                    config.group_size,
                )
                if is_device(q_weight.device.type, "cpu"):
                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                        q_weight.to(child.weight.device),
                        child.inner_k_tiles,
                    )
                else:
                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
                        q_weight.to(child.weight.device),
                        child.inner_k_tiles,
                    )
                quantized_linear.weight = q_weight
                quantized_linear.scales_and_zeros = scales_and_zeros
            else:
                self._convert_qat_linear_4w(child)

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return _get_4w_weight_config(self.groupsize, self.scales_precision)

class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
    """
    This module implements a linear layer with int4 fake quantized grouped per
    channel weights, with forward numerics matching `WeightOnlyInt4Linear`,
    which uses the efficient int4 tinygemm kernel.

    args:
        groupsize: the number of elements in each quantized group for weights
        precision: precision of weights
        scales_precision: precision of per group scales and zero points
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        assert scales_precision == torch.bfloat16, "only bf16 is supported for scales"
        if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
            raise ValueError("Padding for QAT 4w is not supported yet")
        self.inner_k_tiles = inner_k_tiles
        weight_config = _get_4w_weight_config(groupsize, scales_precision)
        super().__init__(
            in_features,
            out_features,
            bias,
            activation_config=None,
            weight_config=weight_config,
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        # this linear is weight-only, so the activation fake quantizer is None
        if self.activation_fake_quantizer is not None:
            self.activation_fake_quantizer.enabled = enabled
        self.weight_fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        self.enable_fake_quant(False)

# TODO: remove these in favor of enable_linear_fake_quant
def enable_4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Enable fake quantization for `Int4WeightOnlyQATLinear`.
    """
    if isinstance(mod, Int4WeightOnlyQATLinear):
        mod.enable_fake_quant()


# TODO: remove these in favor of disable_linear_fake_quant
def disable_4w_fake_quant(mod: torch.nn.Module):
    """
    (deprecated) Disable fake quantization for `Int4WeightOnlyQATLinear`.
    """
    if isinstance(mod, Int4WeightOnlyQATLinear):
        mod.disable_fake_quant()


def _get_4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> IntxFakeQuantizeConfig:
    """
    Return the weight `IntxFakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
    """
    return IntxFakeQuantizeConfig(
        dtype=torch.uint4,
        group_size=group_size,
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        zero_point_domain=ZeroPointDomain.FLOAT,
    )

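# Illustrative usage sketch (an assumption, not part of the module source): the
# int4 weight-only quantizer follows the same prepare/convert flow. Weights are
# kept in bfloat16 to match the tinygemm-oriented defaults; the convert step
# relies on the int4 weight packing ops available in recent PyTorch builds.
# Names prefixed with `_example_` are purely illustrative.
_example_4w_model = torch.nn.Sequential(
    torch.nn.Linear(256, 256, bias=False, dtype=torch.bfloat16)
)
_example_4w_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128, inner_k_tiles=8)
_example_4w_model = _example_4w_quantizer.prepare(_example_4w_model)
# ... QAT fine-tuning loop would go here ...
_example_4w_model = _example_4w_quantizer.convert(_example_4w_model)
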
# =============================================
# | float8 rowwise activations + int4 weights |
# =============================================


class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    QAT quantizer for applying dynamic rowwise float8 activation + int4
    per group/channel symmetric weight fake quantization to linear layers
    in the model. Currently only supports rowwise granularity for float8
    activations.

    args:
        group_size (Optional[int]): the number of elements in each quantized
            group for weights, defaults to 64. Use None for per channel.
        scale_precision: precision of weight scales, defaults to torch.bfloat16.
    """

    def __init__(
        self,
        group_size: Optional[int] = 64,
        scale_precision: torch.dtype = torch.bfloat16,
    ):
        torch._C._log_api_usage_once(
            "torchao.quantization.qat.Float8ActInt4WeightQATQuantizer"
        )
        if group_size is not None:
            weight_granularity = "per_group"
        else:
            weight_granularity = "per_channel"
        self._activation_config = Float8FakeQuantizeConfig(
            dtype=torch.float8_e4m3fn,
            granularity=PerRow(),
        )
        self._weight_config = IntxFakeQuantizeConfig(
            dtype=torch.int4,
            granularity=weight_granularity,
            group_size=group_size,
            is_symmetric=True,
            is_dynamic=True,
            scale_precision=scale_precision,
        )
    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Swap all `nn.Linear` with `FakeQuantizedLinear` with float8
        fake quantizer for activations and int4 fake quantizer for weights.
        """
        for name, child in model.named_children():
            if isinstance(child, torch.nn.Linear):
                new_linear = FakeQuantizedLinear.from_linear(
                    child,
                    activation_config=self._activation_config,
                    weight_config=self._weight_config,
                )
                setattr(model, name, new_linear)
            else:
                self.prepare(child)
        return model
    # TODO: add convert path
    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        raise NotImplementedError

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet")

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfigBase]:
        return self._weight_config

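# Illustrative usage sketch (an assumption, not part of the module source):
# `prepare` swaps every nn.Linear for a FakeQuantizedLinear with rowwise float8
# activation fake quantization and int4 grouped weight fake quantization; the
# convert path is not implemented yet (see the TODO above). Names prefixed with
# `_example_` are purely illustrative.
_example_fp8_model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False))
_example_fp8_quantizer = Float8ActInt4WeightQATQuantizer(group_size=64)
_example_fp8_model = _example_fp8_quantizer.prepare(_example_fp8_model)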