Source code for torchao.quantization.qat.embedding
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Optional
import torch
import torch.nn.functional as F
from torchao.quantization.quant_primitives import TorchAODType
from torchao.quantization.unified import TwoStepQuantizer
from torchao.quantization.utils import get_group_qparams_symmetric
from .api import FakeQuantizeConfig
from .fake_quantizer import FakeQuantizer
from .utils import (
_get_qmin_qmax,
)
class FakeQuantizedEmbedding(torch.nn.Embedding):
"""
    General embedding layer with fake quantized weights.

    The specific target dtype, granularity, scheme, etc. are specified
    through a separate config for the weights.

    Example usage::

        weight_config = FakeQuantizeConfig(
            dtype=torch.int4,
            group_size=8,
            is_symmetric=True,
        )
        fq_embedding = FakeQuantizedEmbedding(5, 16, weight_config=weight_config)
        fq_embedding(torch.LongTensor([3]))
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
weight_config: Optional[FakeQuantizeConfig] = None,
*args,
**kwargs,
) -> None:
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
*args,
**kwargs,
)
if weight_config is not None:
self.weight_fake_quantizer = FakeQuantizer(weight_config)
else:
self.weight_fake_quantizer = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.weight_fake_quantizer is not None:
w = self.weight_fake_quantizer(self.weight)
else:
w = self.weight
return F.embedding(
x,
w,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
def to_embedding(self) -> torch.nn.Embedding:
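        """
        Create a plain `torch.nn.Embedding` from this module, copying its
        arguments and reusing its weights when they are not on the meta device.
        """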
new_embedding = torch.nn.Embedding(
self.num_embeddings,
self.embedding_dim,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
device=self.weight.device,
dtype=self.weight.dtype,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if self.weight.device != torch.device("meta"):
new_embedding.weight = self.weight
return new_embedding
@classmethod
def from_embedding(
cls,
mod: torch.nn.Embedding,
weight_config: Optional[FakeQuantizeConfig] = None,
):
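        """
        Create a `FakeQuantizedEmbedding` from an existing `torch.nn.Embedding`,
        optionally attaching a weight fake quantization config and reusing the
        original weights when they are not on the meta device.
        """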
new_embedding = FakeQuantizedEmbedding(
mod.num_embeddings,
mod.embedding_dim,
mod.padding_idx,
mod.max_norm,
mod.norm_type,
mod.scale_grad_by_freq,
mod.sparse,
weight_config=weight_config,
device=mod.weight.device,
dtype=mod.weight.dtype,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if mod.weight.device != torch.device("meta"):
new_embedding.weight = mod.weight
return new_embedding
# ======================================
# | Embedding int4 weight-only QAT |
# ======================================
class Int4WeightOnlyEmbeddingQATQuantizer(TwoStepQuantizer):
"""
Quantizer for performing QAT on a model, where embedding layers have
int4 fake quantized grouped per channel weights.
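
    Example usage (a minimal sketch; `model` stands for any module containing
    `nn.Embedding` layers whose embedding dim is divisible by the group size,
    and the training loop is elided)::

        quantizer = Int4WeightOnlyEmbeddingQATQuantizer(group_size=32)
        model = quantizer.prepare(model)   # swap in fake quantized embeddings
        # ... train the model with fake quantization ...
        model = quantizer.convert(model)   # swap in int4 quantized embeddings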
"""
def __init__(
self,
group_size: int = 256,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
) -> None:
super().__init__()
self.bit_width = 4
self.group_size: int = group_size
self.scale_precision: torch.dtype = scale_precision
self.zero_point_precision: torch.dtype = zero_point_precision
    def prepare(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
"""
Swap `nn.Embedding` modules with `Int4WeightOnlyQATEmbedding`.
"""
# avoid circular imports
from torchao.quantization.quant_api import (
_replace_with_custom_fn_if_matches_filter,
)
def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
return isinstance(child, torch.nn.Embedding)
def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
new_embedding = Int4WeightOnlyQATEmbedding(
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=self.group_size,
scale_precision=self.scale_precision,
zero_point_precision=self.zero_point_precision,
device=child.weight.device,
dtype=child.weight.dtype,
)
# In distributed training, the model may be instantiated
# on the meta device, in which case there is no need to
# copy the weights, and doing so will result in an error
if child.weight.device != torch.device("meta"):
new_embedding.weight = child.weight
return new_embedding
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, filter_fn)
return model
    def convert(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
"""
Swap all `Int4WeightOnlyQATEmbedding` modules with `Int4WeightOnlyEmbedding`.
"""
self._convert_helper(model)
return model
def _convert_helper(self, module: torch.nn.Module):
"""
Helper function to recursively swap `Int4WeightOnlyQATEmbedding`
modules with `Int4WeightOnlyEmbedding`
"""
from torchao._executorch_ops import (
_quantized_decomposed_quantize_per_channel_group_wrapper,
)
for name, child in module.named_children():
if isinstance(child, Int4WeightOnlyQATEmbedding):
group_size = child.weight_fake_quantizer.config.group_size
scale_precision = child.weight_fake_quantizer.config.scale_precision
zero_point_precision = (
child.weight_fake_quantizer.config.zero_point_precision
)
quantized_embedding = Int4WeightOnlyEmbedding(
# nn.Embedding args
num_embeddings=child.num_embeddings,
embedding_dim=child.embedding_dim,
padding_idx=child.padding_idx,
max_norm=child.max_norm,
norm_type=child.norm_type,
scale_grad_by_freq=child.scale_grad_by_freq,
sparse=child.sparse,
# quantization args
group_size=group_size,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
device=child.weight.device,
output_dtype=child.weight.dtype,
)
setattr(module, name, quantized_embedding)
# Load weights and qparams into quantized embedding
(qmin, qmax) = _get_qmin_qmax(self.bit_width)
(s, zp) = get_group_qparams_symmetric(
child.weight,
self.bit_width,
group_size,
precision=scale_precision,
)
zp = zp.to(zero_point_precision)
q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
child.weight,
s,
zp,
qmin,
qmax,
torch.int8,
group_size,
)
quantized_embedding.weight = q_weight
quantized_embedding.scale = s.to(scale_precision)
quantized_embedding.zero_point = zp.to(zero_point_precision)
else:
self._convert_helper(child)
class Int4WeightOnlyQATEmbedding(FakeQuantizedEmbedding):
"""
    This module implements an embedding layer with int4 fake quantized
    grouped per-channel weights.

    Args:
        group_size: the number of elements in each quantized group for weights
        scale_precision: precision of per group scales
        zero_point_precision: precision of per group zero points
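
    Example usage (a minimal sketch, following the `FakeQuantizedEmbedding`
    example above; the embedding dim must be divisible by `group_size`)::

        qat_embedding = Int4WeightOnlyQATEmbedding(5, 64, group_size=32)
        qat_embedding(torch.LongTensor([3]))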
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
*args,
**kwargs,
):
weight_config = FakeQuantizeConfig(
dtype=TorchAODType.INT4,
group_size=group_size,
is_symmetric=True,
is_dynamic=True,
scale_precision=scale_precision,
zero_point_precision=zero_point_precision,
)
super().__init__(
num_embeddings,
embedding_dim,
padding_idx,
max_norm,
norm_type,
scale_grad_by_freq,
sparse,
weight_config,
*args,
**kwargs,
)
def enable_fake_quant(self, enabled: bool = True):
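        """
        Enable or disable fake quantization on this embedding's weights.
        """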
self.weight_fake_quantizer.enabled = enabled
def disable_fake_quant(self):
self.enable_fake_quant(False)
class Int4WeightOnlyEmbedding(torch.nn.Module):
"""
    This module implements an embedding layer with int4 quantized
    grouped per-channel weights.
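
    The int4 values are currently stored unpacked in an int8 `weight` buffer of
    shape `(num_embeddings, embedding_dim)`; the `scale` and `zero_point`
    buffers each have shape `(num_embeddings, embedding_dim // group_size)`.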
"""
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: Optional[int] = None,
max_norm: Optional[float] = None,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False,
group_size: int = 32,
scale_precision: torch.dtype = torch.float32,
zero_point_precision: torch.dtype = torch.int32,
        device: Optional[torch.device] = None,
output_dtype: torch.dtype = torch.float32,
):
super().__init__()
# nn.Embedding args
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# quantization args
self.bit_width = 4
self.group_size = group_size
self.scale_precision = scale_precision
self.zero_point_precision = zero_point_precision
self.output_dtype = output_dtype
# currently storing unpacked int8 weights
self.register_buffer(
"weight",
torch.empty(
(num_embeddings, embedding_dim), dtype=torch.int8, device=device
),
)
self.register_buffer(
"scale",
torch.empty(
(num_embeddings, embedding_dim // group_size),
dtype=scale_precision,
device=device,
),
)
self.register_buffer(
"zero_point",
torch.empty(
(num_embeddings, embedding_dim // group_size),
dtype=zero_point_precision,
device=device,
),
)
def forward(self, x):
from torchao.quantization.quant_primitives import (
dequantize_affine,
)
qmin, qmax = _get_qmin_qmax(self.bit_width)
# dequantize_affine casts to output_dtype before scaling
# dequantize_per_channel_group scales and then casts to output_dtype
# The two do not agree when dtype != torch.float32
w_dq = dequantize_affine(
self.weight,
[1, self.group_size],
self.scale,
self.zero_point,
torch.int8,
qmin,
qmax,
output_dtype=self.output_dtype,
)
return F.embedding(
x,
w_dq,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)