
Source code for torchao.quantization.qat.linear

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Optional

import torch
import torch.nn.functional as F

from torchao.dtypes.utils import is_device
from torchao.quantization.granularity import PerGroup
from torchao.quantization.linear_quant_modules import (
    Int8DynActInt4WeightLinear,
    WeightOnlyInt4Linear,
    _check_linear_int4_k,
    _replace_linear_8da4w,
    _replace_linear_int4,
    groupwise_affine_quantize_tensor,
)
from torchao.quantization.quant_primitives import (
    TorchAODType,
    ZeroPointDomain,
)
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric
from torchao.utils import TORCH_VERSION_AT_LEAST_2_6

from .api import FakeQuantizeConfig
from .fake_quantizer import (
    FakeQuantizer,
    _Float8RowwiseActivationFakeQuantizer,
)
from .utils import (
    _get_qmin_qmax,
)


class FakeQuantizedLinear(torch.nn.Linear):
    """
    General linear layer with fake quantized weights and activations.

    Specific target dtypes, granularity, schemes etc. are specified
    through separate configs for weights and activations.

    Example usage::

        activation_config = FakeQuantizeConfig(
            dtype=torch.int8,
            granularity="per_token",
            is_symmetric=False,
        )
        weight_config = FakeQuantizeConfig(
            dtype=torch.int4,
            group_size=8,
            is_symmetric=True,
        )
        fq_linear = FakeQuantizedLinear(
            16, 32, False, activation_config, weight_config,
        )
        fq_linear(torch.randn(16))
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        activation_config: Optional[FakeQuantizeConfig] = None,
        weight_config: Optional[FakeQuantizeConfig] = None,
        *args,
        **kwargs,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias,
            *args,
            **kwargs,
        )
        # initialize activation fake quantizer
        if activation_config is not None:
            self.activation_fake_quantizer = FakeQuantizer(activation_config)
        else:
            self.activation_fake_quantizer = None

        # initialize weight fake quantizer
        if weight_config is not None:
            if isinstance(weight_config.granularity, PerGroup):
                group_size = weight_config.group_size
                if group_size is not None and in_features % group_size != 0:
                    raise ValueError(
                        "in_features (%s) %% group_size (%s) must be == 0"
                        % (in_features, group_size)
                    )
            self.weight_fake_quantizer = FakeQuantizer(weight_config)
        else:
            self.weight_fake_quantizer = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.activation_fake_quantizer is not None:
            x = self.activation_fake_quantizer(x)
        if self.weight_fake_quantizer is not None:
            w = self.weight_fake_quantizer(self.weight)
        else:
            w = self.weight
        return F.linear(x, w, self.bias)

    def to_linear(self) -> torch.nn.Linear:
        new_linear = torch.nn.Linear(
            self.in_features,
            self.out_features,
            self.bias is not None,
            device=self.weight.device,
            dtype=self.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if self.weight.device != torch.device("meta"):
            new_linear.weight = self.weight
            new_linear.bias = self.bias
        return new_linear

    @classmethod
    def from_linear(
        cls,
        mod: torch.nn.Linear,
        activation_config: Optional[FakeQuantizeConfig] = None,
        weight_config: Optional[FakeQuantizeConfig] = None,
    ):
        new_linear = FakeQuantizedLinear(
            mod.in_features,
            mod.out_features,
            mod.bias is not None,
            activation_config=activation_config,
            weight_config=weight_config,
            device=mod.weight.device,
            dtype=mod.weight.dtype,
        )
        # In distributed training, the model may be instantiated
        # on the meta device, in which case there is no need to
        # copy the weights, and doing so will result in an error
        if mod.weight.device != torch.device("meta"):
            new_linear.weight = mod.weight
            new_linear.bias = mod.bias
        return new_linear
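
As a quick illustration of the API above (not part of the module source), the sketch below wraps an existing nn.Linear via from_linear and later recovers a plain layer via to_linear; the layer sizes and config values are arbitrary assumptions::

    import torch
    from torchao.quantization.qat.api import FakeQuantizeConfig
    from torchao.quantization.qat.linear import FakeQuantizedLinear

    # Arbitrary example shapes; int4 symmetric weight fake quantization per group of 32
    linear = torch.nn.Linear(128, 256)
    weight_config = FakeQuantizeConfig(dtype=torch.int4, group_size=32, is_symmetric=True)
    fq_linear = FakeQuantizedLinear.from_linear(linear, weight_config=weight_config)

    # The forward pass fake quantizes the weights; to_linear() returns a plain nn.Linear
    out = fq_linear(torch.randn(2, 128))
    plain_linear = fq_linear.to_linear()
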


# ===========================
# | QAT quantizer interface |
# ===========================


class _LegacyQATQuantizer(TwoStepQuantizer):
    """
    Base class for sharing common methods across legacy QAT quantizers.
    """

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return None

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return None


def enable_linear_fake_quant(
    mod: torch.nn.Module,
    enabled: bool = True,
):
    """
    Helper function to enable fake quantization in `FakeQuantizedLinear`.
    """
    if isinstance(mod, FakeQuantizedLinear):
        if mod.activation_fake_quantizer is not None:
            mod.activation_fake_quantizer.enabled = enabled
        if mod.weight_fake_quantizer is not None:
            mod.weight_fake_quantizer.enabled = enabled


def disable_linear_fake_quant(mod: torch.nn.Module):
    """
    Helper function to disable fake quantization in `FakeQuantizedLinear`.
    """
    enable_linear_fake_quant(mod, enabled=False)
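
These helpers operate on a single module, so the intended pattern (sketched below with a hypothetical toy model, not part of the source) is to apply them recursively with nn.Module.apply::

    import torch
    from torchao.quantization.qat.api import FakeQuantizeConfig
    from torchao.quantization.qat.linear import (
        FakeQuantizedLinear,
        disable_linear_fake_quant,
        enable_linear_fake_quant,
    )

    # Hypothetical toy model containing a fake quantized linear layer
    weight_config = FakeQuantizeConfig(dtype=torch.int4, group_size=8, is_symmetric=True)
    model = torch.nn.Sequential(FakeQuantizedLinear(16, 32, weight_config=weight_config))

    model.apply(disable_linear_fake_quant)  # run in full precision, e.g. for a sanity check
    model.apply(enable_linear_fake_quant)   # turn fake quantization back on for QAT
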


# ===========================================
# | int8 dynamic activations + int4 weights |
# ===========================================


class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have int8
    dynamic per token fake quantized activations and int4 fake quantized
    grouped per channel weights.
    """

    def __init__(
        self,
        groupsize: int = 256,
        padding_allowed: bool = False,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        super().__init__()
        self.groupsize: int = groupsize
        self.padding_allowed: bool = padding_allowed
        self.precision: torch.dtype = precision
        self.scales_precision: torch.dtype = scales_precision
        # TODO: generalize this
        self.activation_scales_precision = torch.float32

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        _replace_linear_8da4w(
            model,
            self.groupsize,
            self.padding_allowed,
            self.precision,
            self.scales_precision,
            Int8DynActInt4WeightQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        self._convert_qat_linear_8da4w(model)
        return model

    def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
        """
        Replace all `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`.
        """
        for name, child in module.named_children():
            if isinstance(child, Int8DynActInt4WeightQATLinear):
                config = child.weight_fake_quantizer.config
                quantized_linear = Int8DynActInt4WeightLinear(
                    child.in_features,
                    child.out_features,
                    child.bias is not None,
                    groupsize=config.group_size,
                    precision=child.weight.dtype,
                    scales_precision=config.scale_precision,
                )
                setattr(module, name, quantized_linear)

                # Load weights and qparams into quantized linear
                n_bit = 4
                (qmin, qmax) = _get_qmin_qmax(n_bit)
                (s, zp) = get_group_qparams_symmetric(
                    child.weight,
                    n_bit,
                    config.group_size,
                    precision=config.scale_precision,
                )
                zp = zp.to(config.zero_point_precision)
                from torchao._executorch_ops import (
                    _quantized_decomposed_quantize_per_channel_group_wrapper,
                )

                q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
                    child.weight,
                    s,
                    zp,
                    qmin,
                    qmax,
                    torch.int8,
                    config.group_size,
                )
                quantized_linear.weight = q_weight
                quantized_linear.scales = s
                quantized_linear.zeros = zp
                if child.bias is not None:
                    quantized_linear.bias = child.bias
            else:
                self._convert_qat_linear_8da4w(child)

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return _get_8da4w_activation_config(self.activation_scales_precision)

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
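
This quantizer follows the usual two-step prepare/convert flow. A minimal sketch (the toy model, shapes, and training step below are illustrative assumptions, not part of the source)::

    import torch
    from torchao.quantization.qat.linear import Int8DynActInt4WeightQATQuantizer

    model = torch.nn.Sequential(torch.nn.Linear(512, 512))
    quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)

    # prepare() swaps nn.Linear for Int8DynActInt4WeightQATLinear (fake quantized)
    model = quantizer.prepare(model)

    # ... fine-tune as usual; fake quantization runs in the forward pass ...
    model(torch.randn(4, 512)).sum().backward()

    # convert() swaps in Int8DynActInt4WeightLinear with actually quantized weights
    model = quantizer.convert(model)
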
class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
    """
    This module implements a linear layer with int8 dynamic per token fake
    quantized activations with int4 fake quantized grouped per channel weights.

    args:
        groupsize: the number of elements in each quantized group for weights
        precision: precision of weights
        scales_precision: precision of per group scales and zero points

    Note: we hardcode activation scales to use torch.fp32, but allow users to
    specify the weight scales (defaults to torch.fp32). To get an exact numerical
    match with Int8DynamicActivationInt4WeightConfig, users must use the same dtype
    for both the weights and the scales. Here scales_precision refers specifically
    to the weight scales only, not the activation scales.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
    ) -> None:
        # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
        # which is used in PTQ routines
        # TODO: generalize this
        activation_config = _get_8da4w_activation_config(torch.float32)
        weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
        super().__init__(
            in_features,
            out_features,
            bias,
            activation_config,
            weight_config,
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        self.activation_fake_quantizer.enabled = enabled
        self.weight_fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    Enable fake quantization for `Int8DynActInt4WeightQATLinear`.
    """
    if isinstance(mod, Int8DynActInt4WeightQATLinear):
        mod.enable_fake_quant()


# TODO: remove in favor of disable_linear_fake_quant
def disable_8da4w_fake_quant(mod: torch.nn.Module):
    """
    Disable fake quantization for `Int8DynActInt4WeightQATLinear`.
    """
    if isinstance(mod, Int8DynActInt4WeightQATLinear):
        mod.disable_fake_quant()


def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
    """
    Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
    """
    # TODO: generalize this
    assert qparams_precision == torch.float32
    return FakeQuantizeConfig(
        dtype=torch.int8,
        granularity="per_token",
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        eps=torch.finfo(qparams_precision).eps,
    )


def _get_8da4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> FakeQuantizeConfig:
    """
    Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
    """
    return FakeQuantizeConfig(
        dtype=TorchAODType.INT4,
        group_size=group_size,
        is_symmetric=True,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
    )


# ====================
# | int4 weight-only |
# ====================

class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
    """
    Quantizer for performing QAT on a model, where linear layers have
    int4 fake quantized grouped per channel weights.
    """

    def __init__(
        self,
        groupsize: int = 256,
        inner_k_tiles: Optional[int] = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        super().__init__()
        assert inner_k_tiles in [2, 4, 8]
        assert groupsize in [32, 64, 128, 256]
        self.inner_k_tiles = inner_k_tiles
        self.groupsize = groupsize
        self.precision = precision
        self.scales_precision = scales_precision

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        _replace_linear_int4(
            model,
            self.groupsize,
            self.inner_k_tiles,
            padding_allowed=True,
            precision=self.precision,
            scales_precision=self.scales_precision,
            linear_class=Int4WeightOnlyQATLinear,
            copy_weights=True,
        )
        return model

    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        self._convert_qat_linear_4w(model)
        return model

    def _convert_qat_linear_4w(self, module: torch.nn.Module):
        """
        Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`.
        """
        for name, child in module.named_children():
            if isinstance(child, Int4WeightOnlyQATLinear):
                in_features = child.in_features
                out_features = child.out_features
                inner_k_tiles = child.inner_k_tiles
                config = child.weight_fake_quantizer.config
                quantized_linear = WeightOnlyInt4Linear(
                    in_features,
                    out_features,
                    bias=False,
                    groupsize=config.group_size,
                    inner_k_tiles=inner_k_tiles,
                    precision=child.weight.dtype,
                    scales_precision=config.scale_precision,
                    device=next(child.parameters()).device,
                )
                setattr(module, name, quantized_linear)

                # Load weights and qparams into quantized linear
                n_bit = 4
                (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor(
                    child.weight,
                    n_bit,
                    config.group_size,
                )
                if (
                    is_device(q_weight.device.type, "cpu")
                    and TORCH_VERSION_AT_LEAST_2_6
                ):
                    q_weight = torch.ops.aten._convert_weight_to_int4pack_for_cpu(
                        q_weight.to(child.weight.device),
                        child.inner_k_tiles,
                    )
                else:
                    q_weight = torch.ops.aten._convert_weight_to_int4pack(
                        q_weight.to(child.weight.device),
                        child.inner_k_tiles,
                    )
                quantized_linear.weight = q_weight
                quantized_linear.scales_and_zeros = scales_and_zeros
            else:
                self._convert_qat_linear_4w(child)

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return _get_4w_weight_config(self.groupsize, self.scales_precision)
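
A minimal sketch of the same prepare/convert flow for weight-only int4 (the shapes and precision below are assumptions; the convert step relies on the int4 tinygemm packing ops, i.e. recent CUDA, or CPU with PyTorch >= 2.6)::

    import torch
    from torchao.quantization.qat.linear import Int4WeightOnlyQATQuantizer

    # bfloat16 weights, since the tinygemm int4 kernel path expects bf16
    model = torch.nn.Sequential(torch.nn.Linear(256, 256)).to(torch.bfloat16)
    quantizer = Int4WeightOnlyQATQuantizer(groupsize=128, inner_k_tiles=8)

    model = quantizer.prepare(model)   # swap in Int4WeightOnlyQATLinear
    # ... fine-tune with int4 fake quantized weights ...
    model = quantizer.convert(model)   # swap in WeightOnlyInt4Linear with packed weights
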
class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
    """
    This module implements a linear layer with int4 fake quantized grouped
    per channel weights, with forward numerics matching `WeightOnlyInt4Linear`,
    which uses the efficient int4 tinygemm kernel.

    args:
        groupsize: the number of elements in each quantized group for weights
        precision: precision of weights
        scales_precision: precision of per group scales and zero points
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
        device: torch.device = None,
        groupsize: int = 256,
        inner_k_tiles: int = 8,
        precision: torch.dtype = torch.bfloat16,
        scales_precision: torch.dtype = torch.bfloat16,
    ) -> None:
        assert scales_precision == torch.bfloat16, "only bf16 is supported for scales"
        if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
            raise ValueError("Padding for QAT 4w is not supported yet")
        self.inner_k_tiles = inner_k_tiles
        weight_config = _get_4w_weight_config(groupsize, scales_precision)
        super().__init__(
            in_features,
            out_features,
            bias,
            activation_config=None,
            weight_config=weight_config,
            device=device,
            dtype=precision,
        )

    def enable_fake_quant(self, enabled: bool = True):
        # This layer is weight-only, so the activation fake quantizer may be None
        if self.activation_fake_quantizer is not None:
            self.activation_fake_quantizer.enabled = enabled
        self.weight_fake_quantizer.enabled = enabled

    def disable_fake_quant(self):
        self.enable_fake_quant(False)


# TODO: remove these in favor of enable_linear_fake_quant
def enable_4w_fake_quant(mod: torch.nn.Module):
    """
    Enable fake quantization for `Int4WeightOnlyQATLinear`.
    """
    if isinstance(mod, Int4WeightOnlyQATLinear):
        mod.enable_fake_quant()


# TODO: remove these in favor of disable_linear_fake_quant
def disable_4w_fake_quant(mod: torch.nn.Module):
    """
    Disable fake quantization for `Int4WeightOnlyQATLinear`.
    """
    if isinstance(mod, Int4WeightOnlyQATLinear):
        mod.disable_fake_quant()


def _get_4w_weight_config(
    group_size: int,
    qparams_precision: torch.dtype,
) -> FakeQuantizeConfig:
    """
    Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
    """
    return FakeQuantizeConfig(
        dtype=torch.uint4,
        group_size=group_size,
        is_symmetric=False,
        is_dynamic=True,
        scale_precision=qparams_precision,
        zero_point_precision=qparams_precision,
        zero_point_domain=ZeroPointDomain.FLOAT,
    )


# =============================================
# | float8 rowwise activations + int4 weights |
# =============================================


class Float8ActInt4WeightQATQuantizer(_LegacyQATQuantizer):
    """
    QAT quantizer for applying dynamic rowwise float8 activation + int4
    per group/channel symmetric weight fake quantization to linear layers
    in the model. Currently only supports rowwise granularity for float8
    activations.

    args:
        group_size (Optional[int]): the number of elements in each quantized
            group for weights, defaults to 64. Use None for per channel.
        scale_precision: precision of weight scales, defaults to torch.bfloat16.
    """

    def __init__(
        self,
        group_size: Optional[int] = 64,
        scale_precision: torch.dtype = torch.bfloat16,
    ):
        if group_size is not None:
            weight_granularity = "per_group"
        else:
            weight_granularity = "per_channel"
        self._weight_config = FakeQuantizeConfig(
            dtype=torch.int4,
            granularity=weight_granularity,
            group_size=group_size,
            is_symmetric=True,
            is_dynamic=True,
            scale_precision=scale_precision,
        )

    def prepare(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        """
        Swap all `nn.Linear` with `FakeQuantizedLinear` with float8
        fake quantizer for activations and int4 fake quantizer for weights.
        """
        for name, child in model.named_children():
            if isinstance(child, torch.nn.Linear):
                # TODO: add a config for float8?
                new_linear = FakeQuantizedLinear.from_linear(
                    child,
                    weight_config=self._weight_config,
                )
                new_linear.activation_fake_quantizer = (
                    _Float8RowwiseActivationFakeQuantizer()
                )
                setattr(model, name, new_linear)
            else:
                self.prepare(child)
        return model

    # TODO: add convert path
    def convert(
        self, model: torch.nn.Module, *args: Any, **kwargs: Any
    ) -> torch.nn.Module:
        raise NotImplementedError

    def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        raise NotImplementedError("Float8 FakeQuantizeConfig does not exist yet")

    def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
        return self._weight_config
