Source code for torchao.quantization.smoothquant

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Testing out an accuracy-only implementation of SmoothQuant
(https://arxiv.org/pdf/2211.10438.pdf)
Note: this is an application of input-weight equalization, with the addition that the
multiplication by scale is fused into the preceding layer, specifically for relevant
parts of transformer blocks.
"""

import torch
import torch.nn.functional as F

from .utils import (
    dynamically_quantize_per_channel,
    quant_int8_dynamic_per_token_linear,
)

__all__ = [
    "get_scale",
    "SmoothFakeDynQuantMixin",
    "SmoothFakeDynamicallyQuantizedLinear",
    "swap_linear_with_smooth_fq_linear",
    "smooth_fq_linear_to_inference",
    "set_smooth_fq_attribute",
]
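

# Illustrative sketch, not part of the upstream module: the input-weight
# equalization identity that the module docstring refers to. Dividing the
# activations by a per-input-channel scale `s` and multiplying the matching
# weight columns by `s` leaves the linear output unchanged, which is what lets
# SmoothQuant shift quantization difficulty from activations into the weights.
def _equalization_identity_sketch():
    X = torch.randn(4, 16)  # activations, shape (batch, k)
    W = torch.randn(32, 16)  # linear weight, shape (out_features, k)
    s = torch.rand(16) + 0.5  # per-input-channel smoothing scale
    Y_ref = F.linear(X, W)
    Y_eq = F.linear(X / s, W * s)  # W * s scales each input-channel column
    assert torch.allclose(Y_ref, Y_eq, atol=1e-4)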


def get_scale(X_absmax, W_absmax, alpha=0.5):
    """
    Calculate the scale based on abs(max(X)), abs(max(W)), and alpha.

    Args:
        X_absmax (torch.Tensor): Absolute maximum values of the input tensor X.
        W_absmax (torch.Tensor): Absolute maximum values of the weight tensor W.
        alpha (float, optional): Scaling factor. Defaults to 0.5.

    Returns:
        torch.Tensor: The calculated scale of dimension `k` if X is of dimension `b*n*k` and W is of dimension `k*m`.
    """
    X_pow = torch.pow(X_absmax, alpha)
    W_pow = torch.pow(W_absmax, 1.0 - alpha)
    div = X_pow / W_pow
    return div.reshape(-1)
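

# Illustrative sketch, not part of the upstream module: how `get_scale` behaves
# with the default alpha=0.5. Channels whose activation absmax dominates the
# weight absmax get a scale > 1 (pushing more of the dynamic range into the
# weights); if the two absmax vectors are equal, the scale is all ones.
def _get_scale_example():
    x_absmax = torch.tensor([10.0, 1.0, 4.0])
    w_absmax = torch.tensor([0.1, 1.0, 4.0])
    # with alpha=0.5 this is elementwise sqrt(x_absmax / w_absmax)
    assert torch.allclose(get_scale(x_absmax, w_absmax), torch.tensor([10.0, 1.0, 1.0]))
    assert torch.equal(get_scale(w_absmax, w_absmax), torch.ones(3))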


class SmoothFakeDynQuantMixin(torch.nn.Module):
    def init_smoothquant_variables(self, alpha):
        self.calibrating = True
        self.x_running_abs_max = None
        self.register_buffer("smooth_scale", None)
        self.alpha = alpha
        # debug only
        self.debug_skip_scaling = False
        # self.debug_skip_scaling = True

        # Currently the cuBLAS kernel underlying torch._int_mm does not work
        # with a non-contiguous weight. However, compiling torch._int_mm with
        # torch.compile produces Triton code which is ~2x faster if the weight
        # is transposed. So, for now we have a debug flag to toggle whether
        # we store the quantized weight transposed, so that we get correct
        # numerics both in eager mode and after torch.compile:
        # set to True for cuBLAS / eager mode, and to False for torch.compile.
        # self.store_w_int_repr_t = True
        self.store_w_int_repr_t = False

    def update_x_running_abs_max(self, X):
        # update the running max of incoming activations
        all_dims_except_last = tuple(range(len(X.shape) - 1))
        cur_abs_max = torch.amax(torch.abs(X), dim=all_dims_except_last)
        if self.x_running_abs_max is None:
            self.x_running_abs_max = cur_abs_max
        else:
            self.x_running_abs_max = torch.max(cur_abs_max, self.x_running_abs_max)

    def get_scaled_quantized_w(self):
        # inference
        assert self.smooth_scale is not None, (
            "self.smooth_scale is None, did you call `to_inference()`?"
        )
        W = self.weight

        # scale weight
        # in the future, this can be done ahead of time instead of
        # during inference
        if not self.debug_skip_scaling:
            # TODO(future): do below in `to_inference` instead of here
            W = torch.matmul(
                torch.diag(self.smooth_scale), W.transpose(0, 1)
            ).transpose(0, 1)

        # fake quantize input and weight, and then do matmul in fp32/fp16
        # in the future, this should be replaced with quantized kernels which
        # work on NVIDIA GPUs (such as protoquant's implementation)
        W_int_repr, W_scales, W_zps = dynamically_quantize_per_channel(
            W, -128, 127, torch.int8
        )
        W_int_repr = W_int_repr.contiguous()
        return W_int_repr, W_scales, W_zps

    def to_inference(self):
        raise NotImplementedError()

    def fold_weight(self):
        # note: _W_zps are zeroes and they are ignored
        # TODO(future PR): set up serialization for this
        W_int_repr, self.W_scales, _W_zps = self.get_scaled_quantized_w()
        # need to store transposed weights to make the eager mode matmul
        # op work with cuBLAS, or non-transposed to make it fast in torch.compile
        if self.store_w_int_repr_t:
            self.register_buffer("W_int_repr", W_int_repr.transpose(0, 1).contiguous())
        else:
            self.register_buffer("W_int_repr", W_int_repr.contiguous())
        del self.weight

    def set_debug_x_absmax(self):
        """
        Sets `self.x_running_abs_max` to a value which will lead to a smooth
        scale of all ones if `alpha=0.5`, to enable performance benchmarking
        without calibration.
        """
        raise NotImplementedError()


class SmoothFakeDynamicallyQuantizedLinear(SmoothFakeDynQuantMixin, torch.nn.Linear):
    """
    This is a replacement for `torch.nn.Linear` which implements dynamic per-token
    activation quantization and dynamic per-channel weight quantization based on
    SmoothQuant scaling.
    """

    def __init__(self, *args, **kwargs):
        alpha = kwargs.pop("alpha")
        super().__init__(*args, **kwargs)
        self.init_smoothquant_variables(alpha)

    def forward(self, X, *args, **kwargs):
        if self.calibrating:
            self.update_x_running_abs_max(X)
            Y = F.linear(X, self.weight, self.bias)
        else:
            if not self.debug_skip_scaling:
                # Ideally this would be fused into preceding layers
                # but in practice torch.compile fuses it with other
                # ops so the slowdown is minimal
                X = X / self.smooth_scale
            W_int_repr_t = (
                self.W_int_repr if self.store_w_int_repr_t else self.W_int_repr.t()
            )
            Y = quant_int8_dynamic_per_token_linear(
                X, W_int_repr_t, self.W_scales, self.bias, X.dtype
            )
        return Y

    @classmethod
    def from_float(cls, mod, alpha=0.5):
        """
        Converts a `mod` of class `torch.nn.Linear` to the smooth fake quantized
        version of it.  Note: requires calibration.
        """
        # create the new module with a toy size to ensure initialization is fast
        fake_in_features, fake_out_features = 8, 8
        new_mod = cls(
            fake_in_features, fake_out_features, bias=mod.bias is not None, alpha=alpha
        )
        new_mod.in_features = mod.in_features
        new_mod.out_features = mod.out_features
        new_mod.weight = mod.weight
        new_mod.bias = mod.bias
        # TODO: test when creation is on cuda
        device_to_use = next(mod.parameters()).device
        new_mod.to(device_to_use)
        return new_mod

    def to_inference(self):
        """
        Calculates the smoothquant scale based on calibration
        in preparation for inference
        """
        assert self.x_running_abs_max is not None, "no calibration data found"
        self.calibrating = False
        self.smooth_scale = get_scale(
            self.x_running_abs_max,
            torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values,
            alpha=self.alpha,
        )
        self.fold_weight()

    def set_debug_x_absmax(self):
        w_absmax = torch.max(torch.abs(self.weight.transpose(0, 1)), dim=1).values
        self.x_running_abs_max = w_absmax
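

# Illustrative sketch, not part of the upstream module: the calibrate -> infer
# flow for a single layer. The shapes and calibration loop are made up for the
# example, and the inference path relies on the int8 kernels in `.utils`, so
# exact support depends on your PyTorch version and device.
def _single_layer_usage_sketch():
    lin = torch.nn.Linear(16, 32)
    smooth_lin = SmoothFakeDynamicallyQuantizedLinear.from_float(lin, alpha=0.5)

    # calibration: run representative data through the module in fp32
    for _ in range(8):
        smooth_lin(torch.randn(4, 16))

    # freeze the smoothing scale and fold it into the quantized weight
    smooth_lin.to_inference()

    # inference now runs the dynamically quantized int8 path
    return smooth_lin(torch.randn(4, 16))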


#
# utils to use the smooth linear on real models
#

source_cls_to_target_cls = {
    torch.nn.Linear: SmoothFakeDynamicallyQuantizedLinear,
    torch.nn.modules.linear.NonDynamicallyQuantizableLinear: SmoothFakeDynamicallyQuantizedLinear,
}


def swap_linear_with_smooth_fq_linear(
    model, skip_fqn_list=None, cur_fqn="", alpha=0.5
) -> None:
    """
    Replaces linear layers in the model with their
    SmoothFakeDynamicallyQuantizedLinear equivalents.

    Args:
        model (torch.nn.Module): The model containing linear layers to be replaced.
        skip_fqn_list (list of str, optional): List of fully qualified names to skip
            during replacement. Defaults to None.
        cur_fqn (str, optional): The current fully qualified name of the module being
            processed. Defaults to "".
        alpha (float, optional): The scaling factor for SmoothQuant. Defaults to 0.5.

    Returns:
        None
    """
    name_to_child = dict(model.named_children())
    for name, child in name_to_child.items():
        if cur_fqn == "":
            new_fqn = name
        else:
            new_fqn = f"{cur_fqn}.{name}"
        if ((skip_fqn_list is None) or (new_fqn not in skip_fqn_list)) and (
            type(child) in source_cls_to_target_cls.keys()
        ):
            target_cls = source_cls_to_target_cls[type(child)]
            new_child = target_cls.from_float(child, alpha=alpha)
            setattr(model, name, new_child)
        else:
            swap_linear_with_smooth_fq_linear(child, skip_fqn_list, new_fqn, alpha)


def smooth_fq_linear_to_inference(model, debug_skip_calibration=False) -> None:
    """
    Prepares the model for inference by calculating the smoothquant scale for
    each SmoothFakeDynamicallyQuantizedLinear layer.

    Args:
        model (torch.nn.Module): The model containing
            SmoothFakeDynamicallyQuantizedLinear layers.
        debug_skip_calibration (bool, optional): If True, sets the running maximum
            of activations to a debug value for performance benchmarking.
            Defaults to False.

    Returns:
        None
    """
    for _, mod in model.named_modules():
        if isinstance(mod, tuple(source_cls_to_target_cls.values())):
            if debug_skip_calibration:
                mod.set_debug_x_absmax()
            mod.to_inference()


# useful for quickly toggling smoothquant debug settings on all smoothquant
# modules in a model
def set_smooth_fq_attribute(model, attribute_name, new_attribute_val):
    for _, mod in model.named_modules():
        if isinstance(mod, tuple(source_cls_to_target_cls.values())):
            if hasattr(mod, attribute_name):
                setattr(mod, attribute_name, new_attribute_val)
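

# Illustrative sketch, not part of the upstream module: end-to-end usage on a
# toy model. The model, shapes, and calibration loop are made up for the
# example; the quantized inference path depends on the int8 kernels in
# `.utils`, so support varies by PyTorch version and device.
def _model_level_usage_sketch():
    model = torch.nn.Sequential(
        torch.nn.Linear(16, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 8),
    )

    # 1. swap every nn.Linear for its smooth fake-quantized equivalent
    swap_linear_with_smooth_fq_linear(model, alpha=0.5)

    # 2. calibrate: run representative data through the model in fp32
    for _ in range(8):
        model(torch.randn(4, 16))

    # 3. compute the smoothing scales and fold them into the quantized weights
    smooth_fq_linear_to_inference(model)

    # (optional) toggle a debug attribute on every smoothquant module
    set_smooth_fq_attribute(model, "debug_skip_scaling", False)

    # 4. inference now uses the dynamically quantized int8 path
    return model(torch.randn(4, 16))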
