Source code for torchao.quantization.quant_primitives
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torchao.prototype.custom_fp_utils import (
_f32_to_floatx_unpacked,
_floatx_unpacked_to_f32,
_n_ones,
)
from torchao.utils import (
TORCH_VERSION_AT_LEAST_2_3,
TORCH_VERSION_AT_LEAST_2_5,
TORCH_VERSION_AT_LEAST_2_6,
_is_float8_type,
_register_custom_op,
)
__all__ = [
"choose_qparams_affine",
"choose_qparams_affine_with_min_max",
"choose_qparams_affine_floatx",
"quantize_affine",
"dequantize_affine",
"quantize_affine_floatx",
"dequantize_affine_floatx",
"fake_quantize_affine",
"fake_quantize_affine_cachemask",
"choose_qparams_and_quantize_affine_hqq",
"choose_qparams_and_quantize_affine_qqq",
"dequantize_affine_qqq",
"MappingType",
"ZeroPointDomain",
"TorchAODType",
"choose_qparams_affine_float8",
"quantize_affine_float8",
"dequantize_affine_float8",
"choose_qparams_gguf",
"quantize_gguf",
"dequantize_gguf",
]
class MappingType(Enum):
"""How floating point number is mapped to integer number
symmetric mapping means floating point range is symmetrically mapped to integer range
let's say we have floating point range (-3.5, 10.2) and integer range (-8, 7) (int4)
we'll use (-10.2, 10.2) as the range for floating point and map that to (-8, 7)
e.g. scale = (10.2 - (-10.2)) / (7 - (-8))
SYMMETRIC_NO_CLIPPING_ERR is a variant of symmetric mapping, where the scale is the max of smin
and smax, where smin = min_val_neg / quant_min, and smax = max_val_pos / quant_max. By calculating
smin and smax individually, there can be less round error on negative values, and no out-of-range
of all floating point values.
asymmetric mapping means we just directly map the floating point range to integer range,
for the above example, we will map (-3.5, 10.2) to (-8, 7) and calculate quantization parameter
based on this mapping
e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
"""
SYMMETRIC = auto()
SYMMETRIC_NO_CLIPPING_ERR = auto()
ASYMMETRIC = auto()
class ZeroPointDomain(Enum):
"""Enum that indicate whether zero_point is in integer domain or floating point domain
integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
none domain: quantized_val = (float_val / scale)
"""
INT = auto()
FLOAT = auto()
NONE = auto()
class TorchAODType(Enum):
"""
Placeholder for dtypes that do not exist in PyTorch core yet.
"""
# torch.int1 to torch.int7 are available in PyTorch 2.6+
# These will remain here for BC with older PyTorch versions
INT1 = auto()
INT2 = auto()
INT3 = auto()
INT4 = auto()
INT5 = auto()
INT6 = auto()
INT7 = auto()
if TORCH_VERSION_AT_LEAST_2_5:
torch.serialization.add_safe_globals([MappingType, ZeroPointDomain])
FP8_TYPES = {
torch.float8_e4m3fn,
torch.float8_e5m2,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
}
"""
Map from dtype to the bound value of integers
TODO: maybe can replace this with call to torch.iinfo
"""
_DTYPE_TO_QVALUE_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
torch.uint8: (0, 255),
torch.int8: (-128, 127),
torch.int16: (-(2**15), 2**15 - 1),
torch.int32: (-(2**31), 2**31 - 1),
}
_DTYPE_TO_BIT_WIDTH: Dict[Union[torch.dtype, TorchAODType], int] = {
TorchAODType.INT1: 1,
TorchAODType.INT2: 2,
TorchAODType.INT3: 3,
TorchAODType.INT4: 4,
TorchAODType.INT5: 5,
TorchAODType.INT6: 6,
TorchAODType.INT7: 7,
torch.uint8: 8,
torch.int8: 8,
torch.int16: 16,
torch.int32: 32,
}
_SUB_BYTE_UINT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {}
_SUB_BYTE_INT_BOUNDS: Dict[Union[torch.dtype, TorchAODType], Tuple[int, int]] = {
TorchAODType.INT1: (-(2**0), 2**0 - 1),
TorchAODType.INT2: (-(2**1), 2**1 - 1),
TorchAODType.INT3: (-(2**2), 2**2 - 1),
TorchAODType.INT4: (-(2**3), 2**3 - 1),
TorchAODType.INT5: (-(2**4), 2**4 - 1),
TorchAODType.INT6: (-(2**5), 2**5 - 1),
TorchAODType.INT7: (-(2**6), 2**6 - 1),
}
# torch.uintX available only in PyTorch 2.3+
if TORCH_VERSION_AT_LEAST_2_3:
_SUB_BYTE_UINT_BOUNDS = {
torch.uint1: (0, 2**1 - 1),
torch.uint2: (0, 2**2 - 1),
torch.uint3: (0, 2**3 - 1),
torch.uint4: (0, 2**4 - 1),
torch.uint5: (0, 2**5 - 1),
torch.uint6: (0, 2**6 - 1),
torch.uint7: (0, 2**7 - 1),
}
_DTYPE_TO_BIT_WIDTH.update(
{
torch.uint1: 1,
torch.uint2: 2,
torch.uint3: 3,
torch.uint4: 4,
torch.uint5: 5,
torch.uint6: 6,
torch.uint7: 7,
}
)
# torch.intX available only in PyTorch 2.6+
if TORCH_VERSION_AT_LEAST_2_6:
_SUB_BYTE_INT_BOUNDS.update(
{
torch.int1: (-(2**0), 2**0 - 1),
torch.int2: (-(2**1), 2**1 - 1),
torch.int3: (-(2**2), 2**2 - 1),
torch.int4: (-(2**3), 2**3 - 1),
torch.int5: (-(2**4), 2**4 - 1),
torch.int6: (-(2**5), 2**5 - 1),
torch.int7: (-(2**6), 2**6 - 1),
}
)
_DTYPE_TO_BIT_WIDTH.update(
{
torch.int1: 1,
torch.int2: 2,
torch.int3: 3,
torch.int4: 4,
torch.int5: 5,
torch.int6: 6,
torch.int7: 7,
}
)
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_UINT_BOUNDS)
_DTYPE_TO_QVALUE_BOUNDS.update(_SUB_BYTE_INT_BOUNDS)
assert _DTYPE_TO_BIT_WIDTH.keys() == _DTYPE_TO_QVALUE_BOUNDS.keys()
_GGUF_QK_K = 256
_ONES_TABLE = [_n_ones(i) for i in range(8)]
quant_lib = torch.library.Library("torchao", "FRAGMENT")
register_custom_op = _register_custom_op(quant_lib)
# TODO: decide whether we want to allow custom quant_min/quant_max here
def _get_and_check_qmin_qmax(dtype, quant_min, quant_max):
"""Get quant_min and quant_max args based on dtype and also
verify that they are within the range of possible quant_min/quant_max
for dtype
"""
if dtype in FP8_TYPES:
quant_min_lower_bound, quant_max_upper_bound = (
torch.finfo(dtype).min,
torch.finfo(dtype).max,
)
elif dtype not in _DTYPE_TO_QVALUE_BOUNDS:
raise ValueError(f"Unsupported dtype: {dtype}")
else:
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
if quant_min is None:
quant_min = quant_min_lower_bound
if quant_max is None:
quant_max = quant_max_upper_bound
assert quant_min >= quant_min_lower_bound, (
"quant_min out of bound for dtype, "
f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
)
assert quant_max <= quant_max_upper_bound, (
"quant_max out of bound for dtype, "
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
)
return quant_min, quant_max
def _get_reduction_params(block_size, input_size):
"""Given block_size and input size find the parameters for reduction:
Output:
shape_for_reduction: the shape we use to `view` input to prepare it for reduction
reduction_dims: the dims we'll do reduction over
Example::
Input:
block_size: (3, 3, 2, 10)
input_size: (3, 3, 10, 10)
Output:
shape_for_reduction: (3, 3, 5, 2, 10)
reduction_dim: [0, 1, 3, 4]
"""
assert len(block_size) == len(input_size)
shape_for_reduction = []
reduction_dims = []
cur_dim = 0
for i in range(len(block_size)):
if block_size[i] != input_size[i] and block_size[i] > 1:
assert input_size[i] % block_size[i] == 0, (
f"Expecting input size at {i} dimension: {input_size[i]} to be divisible by block_size at {i} dimension: {block_size[i]}"
)
shape_for_reduction.append(input_size[i] // block_size[i])
shape_for_reduction.append(block_size[i])
# reduce over the block_size[i] dim
reduction_dims.append(cur_dim + 1)
cur_dim += 2
else:
# block_size[i] == input_size[i] or block_size[i] == 1
shape_for_reduction.append(input_size[i])
# we only need to reduce over the dimension if block_size is greater than 1
# otherwise it's already the same as reduced dimension
if block_size[i] != 1:
reduction_dims.append(cur_dim)
cur_dim += 1
return shape_for_reduction, reduction_dims
@torch.no_grad()
def quantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams
e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
scale (torch.Tensor): quantization parameter for affine quantization
zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
output_dtype (torch.dtype): requested dtype (e.g. torch.uint8) for the output Tensor
quant_min (Optional[int]): minimum quantized value for the output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for the output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in the integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in the floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Note:
How can block_size represent different granularities?
let's say we have a Tensor of size: (3, 3, 10, 10), here is the table showing how block_size represents different
granularities:
granularity type | block_size
per_tensor | (3, 3, 10, 10)
per_axis (axis=0) | (1, 3, 10, 10)
per_axis (axis=1) | (3, 1, 10, 10)
per_group (groupsize=2) | (3, 3, 10, 2)
per_group (groupsize=2) for axis = 3 | (3, 3, 2, 10)
Output:
quantized tensor with requested dtype
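Example (an illustrative sketch; the shapes, block_size and dtype below are arbitrary choices, not requirements of the API)::
    >>> x = torch.randn(2, 4)
    >>> block_size = (1, 4)  # one set of qparams per row
    >>> scale, zero_point = choose_qparams_affine(x, MappingType.ASYMMETRIC, block_size, torch.int8)
    >>> q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
    >>> q.dtype
    torch.int8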
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
raise ValueError("zero_point should be None when zero_point_domain is NONE")
return _quantize_affine(
input,
block_size,
scale,
zero_point,
output_dtype,
quant_min,
quant_max,
zero_point_domain.name,
)
@register_custom_op
def _quantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
zero_point_domain: str = ZeroPointDomain.INT.name,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library
Note:
zero_point_domain is optional and specifies how we quantize the floating point to quantized data:
INT: quantized_val = (float_val / scale) (integer) + zero_point (integer)
FLOAT: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
NONE: quantized_val = (float_val / scale) | this is primarily used for floatx quantization,
where we do not want to round values to the nearest integer and instead scale and cast.
"""
quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
# workaround for uintx dtypes, since we don't have native Uintx dtype connected with
# torch.uintx dtypes yet
if output_dtype in _SUB_BYTE_UINT_BOUNDS:
output_dtype = torch.uint8
return _quantize_affine_no_dtype_cast(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
output_dtype,
zero_point_domain,
).to(output_dtype)
def _quantize_affine_no_dtype_cast(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: Union[int, float],
quant_max: Union[int, float],
quant_dtype: torch.dtype,
zero_point_domain: str = ZeroPointDomain.INT.name,
) -> torch.Tensor:
"""
The op does the following:
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
the shape after reduction
2. quantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
3. reshape the quantized result to the original shape
"""
# TODO: validations
# TODO: validate scale/zero_point dimensions are compatible with block_size
assert input.dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
], f"Unsupported input dtype: {input.dtype}"
assert len(block_size) == input.dim(), (
f"Got input dim:{input.dim()}, block_size: {block_size}"
)
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
if zero_point is not None and zero_point.numel() > 0:
zero_point = zero_point.view(shape_after_reduction)
else:
# in some cases a missing zero_point is passed in as a tensor
# with numel=0, which we treat the same as None
zero_point = None
if zero_point_domain == ZeroPointDomain.INT.name:
quant = torch.clamp(
torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
)
elif zero_point_domain == ZeroPointDomain.NONE.name:
assert zero_point is None, (
"zero_point should be None when zero_point_domain is NONE"
)
if _is_float8_type(quant_dtype):
quant = torch.clamp(input * scale.reciprocal(), quant_min, quant_max)
else:
quant = torch.clamp(
torch.round(input * (1.0 / scale)), quant_min, quant_max
)
else:
assert zero_point_domain == ZeroPointDomain.FLOAT.name
mid_point = (quant_max + quant_min + 1) / 2
min_val = zero_point - scale * mid_point
quant = torch.clamp(
torch.round((input - min_val) / scale), quant_min, quant_max
)
quant = quant.view(original_shape)
return quant
def dequantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
*,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Args:
input (torch.Tensor): quantized tensor, should match the dtype of the `input_dtype` argument
block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams
e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
scale (Tensor): quantization parameter for affine quantization
zero_point (Tensor): quantization parameter for affine quantization
input_dtype (torch.dtype): dtype (e.g. torch.uint8) of the input Tensor
quant_min (Optional[int]): minimum quantized value for the input Tensor
quant_max (Optional[int]): maximum quantized value for the input Tensor
output_dtype (torch.dtype): dtype for the output Tensor, default is fp32
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in the integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in the floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
dequantized Tensor, with requested dtype or fp32
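Example (an illustrative sketch; `q`, `scale`, `zero_point` and `block_size` are assumed to come from the `quantize_affine` example above)::
    >>> x_dq = dequantize_affine(q, block_size, scale, zero_point, torch.int8, output_dtype=torch.float32)
    >>> x_dq.shape
    torch.Size([2, 4])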
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
raise ValueError("zero_point should be None when zero_point_domain is NONE")
return _dequantize_affine(
input,
block_size,
scale,
zero_point,
input_dtype,
quant_min,
quant_max,
zero_point_domain.name,
output_dtype=output_dtype,
)
@register_custom_op
def _dequantize_affine(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""op definition that has compatible signatures with custom op library"""
# TODO: validate scale/zero_point dimensions are compatible with block_size
if input_dtype not in _SUB_BYTE_UINT_BOUNDS:
assert input.dtype == input_dtype, (
f"Expected: {input_dtype}, got: {input.dtype}"
)
assert output_dtype in [
torch.float32,
torch.float16,
torch.bfloat16,
], f"Unsupported output dtype: {output_dtype}"
quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
return _dequantize_affine_no_dtype_check(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain,
output_dtype,
)
def _dequantize_affine_no_dtype_check(
input: torch.Tensor,
block_size: List[int],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_min: Union[int, float],
quant_max: Union[int, float],
zero_point_domain: Optional[str] = ZeroPointDomain.INT.name,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""This function converts AQT tensors to their high precision floating point representation
The op does the following:
1. figure out the dimension for reduction based on block_size, also reshape the input to align with
the shape after reduction
2. dequantize the input based on the quantization parameters scale and zero_point and args like zero_point_domain
3. reshape the dequantized result to the original shape and change the dtype to output_dtype
"""
assert len(block_size) == input.dim(), (
f"Got input dim:{input.dim()}, block_size: {block_size}"
)
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
original_shape = input.shape
input = input.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction
for i in reduction_dims:
shape_after_reduction[i] = 1
scale = scale.view(shape_after_reduction)
if zero_point is not None:
zero_point = zero_point.view(shape_after_reduction)
if zero_point_domain == ZeroPointDomain.INT.name:
# Force a copy to avoid input modification due
# to upcoming in-place operations.
dequant = input.to(torch.int32, copy=True)
if zero_point is not None:
dequant = dequant - zero_point.to(torch.int32)
dequant = dequant.to(output_dtype)
dequant = dequant * scale
elif zero_point_domain == ZeroPointDomain.NONE.name:
assert zero_point is None, (
"zero_point should be None when zero_point_domain is NONE"
)
dequant = input.to(output_dtype)
dequant = dequant * scale
else:
assert zero_point_domain == ZeroPointDomain.FLOAT.name, (
f"Unexpected zero point domain: {zero_point_domain}"
)
# TODO: this seems to be a detail for tinygemm (converting from uint to int, probably need to refactor this)
mid_point = (quant_max + quant_min + 1) / 2
# This should allocate new memory and avoid input modification
dequant = input - mid_point
dequant = dequant.to(output_dtype)
dequant *= scale
if zero_point is not None:
dequant += zero_point
return dequant.view(original_shape).to(output_dtype)
def fake_quantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> torch.Tensor:
"""
General fake quantize op for quantization-aware training (QAT).
This is equivalent to calling `quantize_affine` + `dequantize_affine`
but without the dtype casts.
Args:
input (torch.Tensor): original float32, float16 or bfloat16 Tensor
block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams
e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
scale (torch.Tensor): quantization parameter for affine quantization
zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values.
quant_min (Optional[int]): minimum quantized value for output Tensor, if not specified, it will be derived from dtype
quant_max (Optional[int]): maximum quantized value for output Tensor, if not specified, it will be derived from dtype
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
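Example (an illustrative sketch; `x`, `scale`, `zero_point` and `block_size` are assumed to come from the `quantize_affine` example, so the result stays in the input's floating point dtype)::
    >>> fq = fake_quantize_affine(x, block_size, scale, zero_point, torch.int8)
    >>> fq.dtype
    torch.float32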
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
raise ValueError("zero_point should be None when zero_point_domain is NONE")
(_, fq) = _do_fake_quantize_affine(
input,
block_size,
scale,
zero_point,
quant_dtype,
quant_min,
quant_max,
zero_point_domain,
)
return fq
def fake_quantize_affine_cachemask(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
General fake quantize op for quantization-aware training (QAT).
This is equivalent to calling `quantize_affine` + `dequantize_affine`
but without the dtype casts.
Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`,
this consumes more memory and returns an additional outlier mask for
intermediate quantized values.
Args:
Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`.
Returns:
A 2-tuple of (
final fake quantized values,
outlier mask for intermediate quantized values
)
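Example (an illustrative sketch; same assumed inputs as the `fake_quantize_affine` example above)::
    >>> fq, mask = fake_quantize_affine_cachemask(x, block_size, scale, zero_point, torch.int8, quant_min=-128, quant_max=127)
    >>> mask.dtype  # True where the intermediate quantized value is within [quant_min, quant_max]
    torch.bool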
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
elif zero_point_domain is ZeroPointDomain.NONE and zero_point is not None:
raise ValueError("zero_point should be None when zero_point_domain is NONE")
(q, dq) = _do_fake_quantize_affine(
input,
block_size,
scale,
zero_point,
quant_dtype,
quant_min,
quant_max,
zero_point_domain,
)
# derive default quant_min/quant_max from quant_dtype so the mask comparison
# below also works when they are not passed in
quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
mask = torch.logical_and((q >= quant_min), (q <= quant_max))
return (dq, mask)
def _do_fake_quantize_affine(
input: torch.Tensor,
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
quant_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Helper function for `fake_quantize_affine` that returns both the
intermediate quantized values and the final dequantized values.
"""
input_dtype = input.dtype
quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
q = _quantize_affine_no_dtype_cast(
input,
block_size,
scale,
zero_point,
quant_min,
quant_max,
quant_dtype,
zero_point_domain.name,
)
dq = _dequantize_affine_no_dtype_check(
q,
block_size,
scale,
zero_point,
quant_min,
quant_max,
zero_point_domain.name,
output_dtype=input_dtype,
)
return (q, dq)
@torch.no_grad()
def choose_qparams_affine(
input: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[Union[int, float]] = None,
quant_max: Optional[Union[int, float]] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
input (torch.Tensor): fp32, bf16, fp16 input Tensor
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
block_size (Tuple[int, ...]): granularity of quantization, i.e. the size of the block of tensor elements that share the same qparams
e.g. when the block size is the same as the input tensor shape, we are using per tensor quantization
target_dtype (torch.dtype): dtype for target quantized Tensor
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
quant_max (Optional[int]): maximum quantized value for target quantized Tensor
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor
preserve_zero (bool): a flag to indicate whether we need zero to be exactly
representable or not, this is typically required for ops that need zero padding, like convolution
it's less important for ops that don't have zero padding in the op itself, like linear.
For example, given a floating point Tensor [1.2, 0.1, 3.0, 4.0, 0.4, 0], if `preserve_zero` is True,
we'll make sure there is an integer value corresponding to the floating point 0, e.g. [-3, -8, 3, 7, -7, -8], and 0 will be mapped to `-8` without loss. But if `preserve_zero` is not True, there won't be such a
guarantee.
If we don't need zero to be exactly representable, we won't do rounding and clamping for zero_point
zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
if zero_point is in integer domain, zero point is added to the quantized integer value during
quantization
if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
value during quantization
default is ZeroPointDomain.INT
Output:
Tuple of scales and zero_points Tensor with requested dtype
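Example (an illustrative sketch; per-row symmetric int8 qparams for a 2D weight, arbitrary shapes)::
    >>> w = torch.randn(16, 32)
    >>> scale, zero_point = choose_qparams_affine(w, MappingType.SYMMETRIC, (1, 32), torch.int8)
    >>> scale.shape, zero_point.shape
    (torch.Size([16]), torch.Size([16]))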
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
return _choose_qparams_affine(
input,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name,
)
def choose_qparams_affine_with_min_max(
min_val: torch.Tensor,
max_val: torch.Tensor,
mapping_type: MappingType,
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""A variant of :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`
operator that pass in min_val and max_val directly instead of deriving these from a single input.
This is used for observers in static quantization where min_val and max_val may be obtained through
tracking all the data in calibration data set.
Args:
Mostly same as :func:`~torchao.quantization.quant_primitives.choose_qparams_affine`. with one
difference: instead of passing in `input` Tensor and use that to calculate min_val/max_val
and then scale/zero_point, we pass in min_val/max_val directly
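Example (an illustrative sketch; min_val/max_val would normally come from an observer)::
    >>> min_val = torch.tensor([-1.0, -2.0])
    >>> max_val = torch.tensor([3.0, 4.0])
    >>> scale, zero_point = choose_qparams_affine_with_min_max(
    ...     min_val, max_val, MappingType.ASYMMETRIC, (1, 8), torch.uint8
    ... )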
"""
if zero_point_domain is None:
raise ValueError("Please use ZeroPointDomain.NONE instead of None")
return _choose_qparams_affine(
None,
mapping_type.name,
block_size,
target_dtype,
quant_min,
quant_max,
eps,
scale_dtype,
zero_point_dtype,
preserve_zero,
zero_point_domain.name,
min_val,
max_val,
)
@register_custom_op
def _choose_qparams_affine(
input: Optional[torch.Tensor],
mapping_type: str,
block_size: List[int],
target_dtype: torch.dtype,
quant_min: Optional[Union[int, float, bool]] = None,
quant_max: Optional[Union[int, float, bool]] = None,
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
preserve_zero: bool = True,
zero_point_domain: Optional[str] = "INT",
min_val: Optional[torch.Tensor] = None,
max_val: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library
The op does the following:
1. figure out the dimension for reduction based on block_size
2. find min_val/max_val based on the dimension for reduction
3. calculate quantization parameters based on min_val/max_val based on args like `preserve_zero`
and `zero_point_domain`
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [
MappingType.SYMMETRIC.name,
MappingType.SYMMETRIC_NO_CLIPPING_ERR.name,
MappingType.ASYMMETRIC.name,
], f"Unsupported mapping type: {mapping_type}"
if target_dtype in FP8_TYPES:
assert mapping_type == MappingType.SYMMETRIC.name, (
f"Only symmetric quantization is supported for FP8 types, got {mapping_type}"
)
if input is not None:
if scale_dtype is None:
scale_dtype = input.dtype
if eps is None:
eps = torch.finfo(input.dtype).eps
assert len(block_size) == input.dim(), (
f"Got input dim:{input.dim()}, block_size: {block_size}"
)
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
input = input.view(shape_for_reduction)
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
else:
assert min_val is not None and max_val is not None, (
f"Need to provide `min_val` and `max_val` when `input` is None, got: {min_val, max_val}"
)
assert min_val.dtype == max_val.dtype, (
f"Expecting `min_val` and `max_val` to have the same dtype, got: {min_val.dtype, max_val.dtype}"
)
if scale_dtype is None:
scale_dtype = min_val.dtype
if eps is None:
eps = torch.finfo(min_val.dtype).eps
if preserve_zero:
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
else:
min_val_neg = min_val
max_val_pos = max_val
if (
mapping_type == MappingType.SYMMETRIC.name
or mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
):
# scales
if mapping_type == MappingType.SYMMETRIC.name:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
else:
assert mapping_type == MappingType.SYMMETRIC_NO_CLIPPING_ERR.name
# calculate smin and smax individually and choose the larger one. For example, if quant_min = -8 and
# quant_max = 7:
# - If smin is bigger: negative values are covered down to -8, with less rounding
#   error than the plain SYMMETRIC case.
# - If smax is bigger: positive values are covered up to 7, and the rounding error
#   may be bigger than the plain SYMMETRIC case. Either way, no floating point value
#   falls out of range after quantization.
smin = min_val_neg / float(quant_min)
smax = max_val_pos / float(quant_max)
mask = smin > smax
scale = torch.where(mask, smin, smax)
# zeros
if not preserve_zero:
raise ValueError(
"preserve_zero == False is not supported for symmetric quantization"
)
if zero_point_domain == ZeroPointDomain.FLOAT.name:
# TODO INT should not be a valid ZeroPointDomain for symmetric quantization since
# symmetric quant doesn't have a zero_point
raise ValueError(
"zero_point_domain should be ZeroPointDomain.INT or ZeroPointDomain.NONE for symmetric quantization"
)
if zero_point_domain == ZeroPointDomain.NONE.name:
zero_point = None
else:
zero_point = torch.full_like(scale, int((quant_max + quant_min + 1) / 2))
scale = torch.clamp(scale, min=eps)
else:
assert mapping_type == MappingType.ASYMMETRIC.name
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.clamp(scale, min=eps)
if zero_point_domain == ZeroPointDomain.NONE.name:
zero_point = None
elif zero_point_domain == ZeroPointDomain.INT.name:
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
if zero_point_dtype is None:
zero_point_dtype = torch.int32
else:
assert zero_point_domain == ZeroPointDomain.FLOAT.name, (
"zero_point must be in FLOAT/INT/None domain for asymmetric quantization"
)
mid_point = (quant_max + quant_min + 1) / 2
# this is not preserving zero_point, this is converting to TensorCoreTiledFormat
# TODO move the conversion of zero_point out of quant_primitives
# and into TensorCoreTiledLayout.from_plain
zero_point = min_val_neg + scale * mid_point
if zero_point is not None:
zero_point = zero_point.to(dtype=zero_point_dtype)
return scale.to(dtype=scale_dtype), zero_point
def choose_qparams_and_quantize_affine_qqq(
w: torch.Tensor,
num_bits: int,
group_size: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
assert num_bits == 4, f"Unsupported num_bits = {num_bits}"
size_n, size_k = w.shape
assert group_size in [-1, 128, size_k], f"Unsupported groupsize = {group_size}"
orig_device = w.device
if group_size == -1:
group_size = size_k
if group_size < size_k:
# Reshape to [-1, group_size]
w = w.reshape((-1, group_size))
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
# Compute scale for each group
s_group = torch.amax(torch.abs(w), -1, keepdim=True)
s_group *= 2 / max_q_val # 2 => symmetric
# Quantize
q_w = torch.round(w / s_group).int()
q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s_group
# Restore original shapes
def reshape_w(w):
w = w.reshape((size_n, size_k)).contiguous()
return w
q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
# Compute int8 quantization scale for each channel
s_channel = torch.amax(torch.abs(w_ref), -1, keepdim=True)
s_channel /= 127.0
t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8)
w_ref = t_int8.half() * s_channel
s_channel = s_channel.reshape(-1, 1).to(dtype=torch.float)
# Fuse scales
s_group = (s_group.reshape(size_n, -1).contiguous() / s_channel).to(
dtype=torch.half
)
else:
max_q_val = 2 ** (num_bits - 1) - 1
# Compute scale for each channel
s_channel = torch.amax(torch.abs(w), -1, keepdim=True)
s_channel /= max_q_val
# Quantize
q_w = torch.round(w / s_channel).int()
q_w = torch.clamp(q_w, -max_q_val, max_q_val)
# Compute ref (dequantized)
w_ref = q_w.half() * s_channel
s_group = torch.tensor([], dtype=torch.half, device=orig_device)
# divide by 2 ** (8 - num_bits) to offset the right shift in unpacking
s_channel /= 2 ** (8 - num_bits)
s_channel = s_channel.reshape(size_n, -1).contiguous().to(torch.float)
return q_w, s_group, s_channel, w_ref
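# Example (an illustrative sketch; the (N, K) shape and group_size below are arbitrary valid choices):
#   q_w, s_group, s_channel, w_ref = choose_qparams_and_quantize_affine_qqq(
#       torch.randn(64, 256), num_bits=4, group_size=128
#   )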
def choose_qparams_gguf(
input: Optional[torch.Tensor],
block_size: List[int],
target_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
There are two sets of qparams: (quantized_block_scale, quantized_block_min) and (super_block_scale_scale, super_block_min_scale)
the relationship is the following:
block_scale = quantized_block_scale * super_block_scale_scale
block_min = quantized_block_min * super_block_min_scale
quantized_val = (float_val - block_min) / block_scale + quant_min
first we calculate block_scale and block_min
then we calculate super_block_scale_scale and super_block_min_scale
after that we can calculate quantized_block_scale and quantized_block_min
the returned values are: super_block_scale_scale, super_block_min_scale, quantized_block_scale
and quantized_block_min
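Example (an illustrative sketch; one Q4_K-style super block of 256 elements split into 32-element blocks, assuming torch.uint4 is available)::
    >>> x = torch.randn(1, 256)
    >>> s_scale, m_scale, q_scale, q_min = choose_qparams_gguf(x, [1, 32], torch.uint4)
    >>> q_scale.shape, s_scale.shape
    (torch.Size([1, 8]), torch.Size([1]))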
"""
dtype = input.dtype
# 1. get block_scale block_min
shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
input = input.view(shape_for_reduction)
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
quant_max = 15
quant_min = 0
# asymmetric quant to fully utilize the range
block_scale = (max_val - min_val) / float(quant_max - quant_min)
block_min = min_val
# 2. get super_block_scale_scale and super_block_min_scale
assert _GGUF_QK_K % block_size[-1] == 0
super_block_size = (1, _GGUF_QK_K // block_size[-1])
shape_for_reduction, reduction_dims = _get_reduction_params(
super_block_size, block_scale.size()
)
block_scale = block_scale.view(shape_for_reduction)
block_min = block_min.view(shape_for_reduction)
shape_after_reduction = shape_for_reduction.copy()
for i in reduction_dims:
shape_after_reduction[i] = 1
block_scale_absmax = torch.amax(
torch.abs(block_scale), dim=reduction_dims, keepdim=False
)
block_min_absmax = torch.amax(
torch.abs(block_min), dim=reduction_dims, keepdim=False
)
# 2. get super_block_scale_scale and super_block_min_scale
# TODO: make this configurable
# we also quantize the quantization parameters (scale and min) for each block to 6 bit
# for Q4_K
qparam_quant_max = 2**6 - 1
qparam_quant_min = 0
super_block_scale_scale = block_scale_absmax / float(
qparam_quant_max - qparam_quant_min
)
super_block_min_scale = block_min_absmax / float(
qparam_quant_max - qparam_quant_min
)
super_block_scale_scale_view = super_block_scale_scale.view(shape_after_reduction)
super_block_min_scale_view = super_block_min_scale.view(shape_after_reduction)
# 3. quantize the block scale and block min, which are stored in 6 bits using super_block_scale_scale and super_block_min_scale
quantized_block_scale = torch.clamp(
block_scale / super_block_scale_scale_view, qparam_quant_min, qparam_quant_max
)
quantized_block_min = torch.clamp(
block_min / super_block_min_scale_view, qparam_quant_min, qparam_quant_max
)
return (
super_block_scale_scale.to(dtype),
super_block_min_scale.to(dtype),
quantized_block_scale.to(dtype),
quantized_block_min.to(dtype),
)
def quantize_gguf(
input: torch.Tensor,
block_size: List[int],
target_dtype: torch.dtype,
super_block_scale_scale: torch.Tensor,
super_block_min_scale: torch.Tensor,
quantized_block_scale: torch.Tensor,
quantized_block_min: torch.Tensor,
) -> torch.Tensor:
assert target_dtype == torch.uint4
# step 1: first order quantization
# just going through shape calculation for block_scale and block_min to get the correct shape
input_shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
block_qparam_shape_after_reduction = input_shape_for_reduction.copy()
for i in reduction_dims:
block_qparam_shape_after_reduction[i] = 1
original_shape = input.shape
input = input.view(input_shape_for_reduction)
quantized_block_scale = quantized_block_scale.view(
block_qparam_shape_after_reduction
)
quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction)
# step 2: second order quantization, recover unquantized block_scale and block_min
super_block_size = (1, _GGUF_QK_K // block_size[-1], 1)
super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params(
super_block_size, quantized_block_scale.size()
)
super_block_qparam_shape_after_reduction = (
super_block_input_shape_for_reduction.copy()
)
for i in reduction_dims:
super_block_qparam_shape_after_reduction[i] = 1
quantized_block_scale = quantized_block_scale.view(
super_block_input_shape_for_reduction
)
quantized_block_min = quantized_block_min.view(
super_block_input_shape_for_reduction
)
super_block_scale_scale = super_block_scale_scale.view(
super_block_qparam_shape_after_reduction
)
super_block_min_scale = super_block_min_scale.view(
super_block_qparam_shape_after_reduction
)
block_scale = super_block_scale_scale * quantized_block_scale
block_min = super_block_min_scale * quantized_block_min
# step 3: quantization with the unquantized block_scale and block_min
block_scale = block_scale.view(block_qparam_shape_after_reduction)
block_min = block_min.view(block_qparam_shape_after_reduction)
int_data = (input - block_min) / block_scale
int_data = int_data.view(original_shape)
return int_data
def dequantize_gguf(
input: torch.Tensor,
block_size: List[int],
target_dtype: torch.dtype,
super_block_scale_scale: torch.Tensor,
super_block_min_scale: torch.Tensor,
quantized_block_scale: torch.Tensor,
quantized_block_min: torch.Tensor,
output_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
# step 1. reshape input and quantized block scale and min to the shape
# after first quantization
input_shape_for_reduction, reduction_dims = _get_reduction_params(
block_size, input.size()
)
block_qparam_shape_after_reduction = input_shape_for_reduction.copy()
for i in reduction_dims:
block_qparam_shape_after_reduction[i] = 1
original_shape = input.shape
input = input.view(input_shape_for_reduction)
quantized_block_scale = quantized_block_scale.view(
block_qparam_shape_after_reduction
)
quantized_block_min = quantized_block_min.view(block_qparam_shape_after_reduction)
# step 2. calculate and reshape block_qparams for second quantization step
super_block_size = (1, _GGUF_QK_K // block_size[-1], 1)
super_block_input_shape_for_reduction, reduction_dims = _get_reduction_params(
super_block_size, quantized_block_scale.size()
)
super_block_qparam_shape_after_reduction = (
super_block_input_shape_for_reduction.copy()
)
for i in reduction_dims:
super_block_qparam_shape_after_reduction[i] = 1
quantized_block_scale = quantized_block_scale.view(
super_block_input_shape_for_reduction
)
quantized_block_min = quantized_block_min.view(
super_block_input_shape_for_reduction
)
super_block_scale_scale = super_block_scale_scale.view(
super_block_qparam_shape_after_reduction
)
super_block_min_scale = super_block_min_scale.view(
super_block_qparam_shape_after_reduction
)
block_scale = super_block_scale_scale * quantized_block_scale
block_min = super_block_min_scale * quantized_block_min
# step 3. dequantize with block_scale and block_min
block_scale = block_scale.view(block_qparam_shape_after_reduction)
block_min = block_min.view(block_qparam_shape_after_reduction)
dequant = input * block_scale + block_min
dequant = dequant.view(original_shape)
if output_dtype is not None:
dequant = dequant.to(output_dtype)
return dequant
def dequantize_affine_qqq(
w: torch.Tensor,
s_group: torch.Tensor,
s_channel: torch.Tensor,
num_bits: int,
group_size: int,
output_dtype: Optional[torch.dtype] = None,
):
assert num_bits == 4, f"Unsupported num_bits = {num_bits}"
size_n, size_k = w.shape
assert group_size in [-1, 128, size_k], f"Unsupported groupsize = {group_size}"
if group_size == -1:
group_size = size_k
if group_size < size_k:
# Reshape to [-1, group_size]
w = w.reshape((-1, group_size))
max_q_val = 2**num_bits - 1
half_q_val = (max_q_val + 1) // 2
s_group = s_group * s_channel.half()
w_dq = (w - half_q_val).half() * s_group.reshape(-1, 1)
# Restore original shapes
def reshape_w(w):
w = w.reshape((size_n, size_k)).contiguous()
return w
w_dq = reshape_w(w_dq)
else:
s_channel = s_channel * (2 ** (8 - num_bits))
w_dq = w.half() * s_channel
if output_dtype is None:
w_dq = w_dq.to(torch.float16)
else:
w_dq = w_dq.to(output_dtype)
return w_dq
# HQQ
############################################################################
# Shrinking operator (proximal operator for the lp norm)
def _shrink_lp_op(x: torch.Tensor, beta: float, lp_norm: float) -> torch.Tensor:
if lp_norm == 1:
return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
else:
return torch.sign(x) * torch.nn.functional.relu(
torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x), lp_norm - 1)
)
# Proximal solver || W - dequantize(quantize(W))||_p^p
@torch.inference_mode()
def optimize_weights_proximal_legacy(
tensor: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
min_max: list,
axis: int = 0,
dtype: Union[torch.dtype, None] = None,
device: Union[str, None] = None,
verbose: bool = False,
opt_params: dict = {
"lp_norm": 0.7,
"beta": 1e1,
"kappa": 1.01,
"iters": 20,
"early_stop": True,
},
) -> tuple:
lp_norm, beta, kappa, iters, early_stop = (
opt_params["lp_norm"],
opt_params["beta"],
opt_params["kappa"],
opt_params["iters"],
opt_params["early_stop"],
)
device = tensor.device if (device is None) else torch.device(device)
if dtype is None:
dtype = torch.float16 if (device.type == "cuda") else torch.float32
W_f = tensor.to(dtype=dtype, device=device)
scale = scale.to(dtype=dtype, device=device)
zero = zero.to(dtype=dtype, device=device)
best_error = 1e4
for i in range(iters):
W_q = torch.round(W_f * scale + zero).clamp(min_max[0], min_max[1])
W_r = (W_q - zero) / scale
W_e = _shrink_lp_op(W_f - W_r, beta, lp_norm)
zero = torch.mean(W_q - (W_f - W_e) * scale, axis=axis, keepdim=True)
beta *= kappa
current_error = float(torch.abs(W_f - W_r).mean())
if verbose:
print("Iter " + str(i + 1), " | Error: " + str(current_error))
if early_stop:
if current_error < best_error:
best_error = current_error
else:
break
scale = scale.to(tensor.device)
zero = zero.to(tensor.device)
del W_f, W_q, W_r, W_e
torch.cuda.empty_cache()
W_q = torch.round(tensor * scale + zero).clamp(min_max[0], min_max[1])
return W_q, scale, zero
# Mainly used to check if numel() is divisible by the group size
def _is_divisible(val1: int, val2: int) -> bool:
return int(val2 * math.ceil(val1 / val2)) == val1
# Converts hqq format W_dequant = (W_q - zero)*scale into affinequantized format: (W_q - mid_point)*scale_ao + zero_ao
def _convert_to_affinequantized_format(
W_q: torch.Tensor,
scale: torch.Tensor,
zero: torch.Tensor,
nbits: int,
shape: Union[List, Tuple, torch.Size],
) -> Tuple:
quant_min = 0
quant_max = 2**nbits - 1
mid_point = (quant_max + quant_min + 1) / 2
zero_ao = ((mid_point - zero.float()) * scale.float()).to(zero.dtype)
scale_ao = scale
W_q_ao = W_q.view(shape)
return W_q_ao, scale_ao, zero_ao
# Main hqq quantizer function
def choose_qparams_and_quantize_affine_hqq(
tensor: torch.Tensor,
nbits: float = 4,
group_size: int = 64,
optimize: bool = True,
axis: int = 1,
compute_dtype: torch.dtype = torch.float16,
device: str = "cuda",
verbose: bool = False, # to check the optimizer error
raw_output: bool = False, # If True, it will return the quant params in hqq lib format
optimize_weights: Callable = optimize_weights_proximal_legacy, # weights proximal optimizer function
) -> tuple:
assert axis in [0, 1], "axis should be either 0 or 1"
if group_size is not None:
assert _is_divisible(tensor.numel(), group_size), (
"group_size should be divisble by the total tensor dimensions. shape: "
+ str(tensor.shape)
+ ", group_size: "
+ str(group_size)
)
# It's better to work with float32 here
W = tensor.to(device=device, dtype=torch.float32)
shape = W.shape
# Reshape for grouping
if group_size is not None:
W = W.reshape([-1, group_size]) if (axis == 1) else W.reshape([group_size, -1])
# Get min/max values
_min = W.min(axis=axis, keepdim=True)[0]
_max = W.max(axis=axis, keepdim=True)[0]
max_v = round(2**nbits - 1)
min_v = 0
min_max = [min_v, max_v]
# Clamp to avoid fp16 issues
scale = (max_v / (_max - _min)).clamp(max=2e4)
zero = -_min * scale
# Round zero as in: https://github.com/casper-hansen/AutoAWQ/blob/main/awq/quantize/quantizer.py#L42C9-L42C14
if nbits in [4]:
zero = torch.round(zero)
# Fine-tune weights
if optimize:
W_q, scale, zero = optimize_weights(
tensor=W,
scale=scale,
zero=zero,
min_max=min_max,
axis=axis,
device=device,
verbose=verbose,
)
else:
zero = zero.to(compute_dtype)
scale = scale.to(compute_dtype)
W_q = torch.round(W * scale + zero).clamp(min_max[0], min_max[1])
# Store meta-data (we invert the scale for dequantization)
scale = 1.0 / scale
# Convert to TensorCoreTiled format
# TODO move the conversion of zero_point out of quant_primitives
# and into TensorCoreTiledLayout.from_plain and rename this
# helper function correctly.
if raw_output is False:
W_q, scale, zero = _convert_to_affinequantized_format(
W_q, scale, zero, nbits, shape
)
else:
# this path was not used before, the way hqq sets up scale/zero is transposed
# compared to the rest of our utils so we need to reshape them accordingly.
W_q = W_q.reshape(shape)
if axis == 1:
scale = scale.reshape(shape[0], -1)
zero = zero.reshape(shape[0], -1)
else:
scale = scale.reshape(-1, shape[-1])
zero = zero.reshape(-1, shape[-1])
# Make sure all the weights are in the right compute_dtype/device
W_q = W_q.to(dtype=torch.uint8, device=device)
scale = scale.to(dtype=compute_dtype, device=device)
zero = zero.to(dtype=compute_dtype, device=device)
# cleanup
del W, _min, _max
torch.cuda.empty_cache()
return W_q, scale, zero, shape
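# Example (an illustrative sketch; assumes a CUDA device is available and uses the default
# affine-quantized output format, i.e. raw_output=False):
#   W_q, scale, zero, shape = choose_qparams_and_quantize_affine_hqq(
#       torch.randn(128, 128), nbits=4, group_size=64, device="cuda"
#   )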
def choose_qparams_affine_floatx(
tensor: torch.Tensor, ebits: int, mbits: int
) -> torch.Tensor:
# _n_ones() is not compatible with torch.compile() due to << operator
# https://github.com/pytorch/pytorch/issues/119152
# exp_bias = _n_ones(ebits - 1)
# max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2 ** mbits))
# workaround: global lookup table
exp_bias = _ONES_TABLE[ebits - 1]
max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (
_ONES_TABLE[mbits + 1] / (2**mbits)
)
dtype = tensor.dtype
tensor = tensor.float()
scale = tensor.abs().amax(1).clamp(min=1e-12) / max_normal
return scale.to(dtype)
def quantize_affine_floatx(
tensor: torch.Tensor, scale: torch.Tensor, ebits: int, mbits: int
) -> torch.Tensor:
"""Quantizes the float32 high precision floating point tensor to low precision floating point number and
converts the result to unpacked floating point format with the format of 00SEEEMM (for fp6_e3m2) where S means sign bit, e means exponent bit and m means mantissa bit
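Example (an illustrative sketch; fp6_e3m2 round trip with per-row scales, arbitrary shape)::
    >>> x = torch.randn(4, 64)
    >>> scale = choose_qparams_affine_floatx(x, ebits=3, mbits=2)
    >>> xq = quantize_affine_floatx(x, scale, 3, 2)  # uint8 tensor holding unpacked fp6 values
    >>> x_dq = dequantize_affine_floatx(xq, scale, 3, 2, output_dtype=torch.float32)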
"""
tensor = tensor.float()
tensor_floatx = _f32_to_floatx_unpacked(tensor / scale.view(-1, 1), ebits, mbits)
return tensor_floatx
def dequantize_affine_floatx(
tensor: torch.Tensor,
scale: torch.Tensor,
ebits: int,
mbits: int,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
tensor = _floatx_unpacked_to_f32(tensor, ebits, mbits)
tensor = tensor * scale.float().view(-1, 1)
tensor = tensor.to(dtype=output_dtype)
return tensor
def choose_qparams_affine_float8(
tensor: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
"""
Calculates float8 scaling factor for the given high precision tensor, using tensorwise granularity.
Args:
tensor (torch.Tensor): Input tensor to be quantized.
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
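Example (an illustrative sketch; tensorwise fp8 round trip, arbitrary shape)::
    >>> x = torch.randn(16, 32)
    >>> scale = choose_qparams_affine_float8(x)  # scalar fp32 scale
    >>> xq = quantize_affine_float8(x, scale)  # torch.float8_e4m3fn
    >>> x_dq = dequantize_affine_float8(xq, scale, output_dtype=torch.bfloat16)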
"""
# only tensorwise scaling is supported for now:
quant_min, quant_max = torch.finfo(float8_dtype).min, torch.finfo(float8_dtype).max
min_val_neg = torch.min(tensor)
max_val_pos = torch.max(tensor)
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
return scale.to(dtype=torch.float32)
def quantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
float8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
"""
Quantizes the high precision floating point tensor to a float8 tensor, using the given scaling factor.
Args:
tensor (torch.Tensor): Input tensor to be quantized.
scale (torch.Tensor): Scaling factor for the quantization.
float8_dtype (torch.dtype): Data type of the quantized tensor (e.g., torch.float8_e4m3fn, torch.float8_e5m2).
"""
# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
# upcasted to `float32` to multiply with the scale, since scale is a fp32 tensor in float8 quantization.
# In order to match numerics between eager and compile, we upcast manually here.
tensor_scaled = tensor.to(torch.float32) / scale
max_value = torch.finfo(float8_dtype).max
tensor_clamped = tensor_scaled.clamp(min=-max_value, max=max_value)
fp8_tensor = tensor_clamped.to(float8_dtype)
return fp8_tensor
def dequantize_affine_float8(
tensor: torch.Tensor,
scale: torch.Tensor,
output_dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
"""
Dequantizes the float8 tensor to high precision tensor.
Args:
tensor (torch.Tensor): Input float8 tensor to be dequantized.
scale (torch.Tensor): Scaling factor for the dequantization.
output_dtype (torch.dtype): Data type of the output tensor (e.g., torch.float32).
"""
# Note: when the line below is compiled with `torch.compile`, `tensor` is automatically
# upcasted to `float32` to divide by the scale, since scale is a fp32 for float8 quantization.
# In order to match numerics between eager and compile, we upcast manually here.
fp8_tensor = tensor.to(torch.float32)
hp_tensor = fp8_tensor * scale
return hp_tensor.to(output_dtype)