# Source code for torchao.dtypes.uintx.cutlass_int4_packed_layout
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional
import torch
from torch.utils._python_dispatch import (
return_and_correct_aliasing,
)
from torchao.dtypes.affine_quantized_tensor import (
AffineQuantizedTensor,
register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
_aqt_is_int8,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout
aten = torch.ops.aten
def _aqt_is_int4(aqt):
"""Check if an AffineQuantizedTensor is int4 quantized Tensor"""
# TODO: use torch.int4
return (
aqt.tensor_impl.dtype == torch.int8
and aqt.quant_min == -8
and aqt.quant_max == 7
)
def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
    """Check whether ``src`` can be copied into ``self`` tensor-by-tensor.

    Both arguments must be ``Int4PackedTensorImpl`` instances that agree on
    the wrapper shape, the packed ``int_data`` shape, the ``scale`` shape, and
    the class of their layout object.
    """
    if not isinstance(self, Int4PackedTensorImpl):
        return False
    if not isinstance(src, Int4PackedTensorImpl):
        return False
    if self.shape != src.shape:
        return False
    if self.int_data.shape != src.int_data.shape:
        return False
    if self.scale.shape != src.scale.shape:
        return False
    # Only the layout *class* has to match, not the layout instance.
    return type(self._layout) is type(src._layout)
@dataclass(frozen=True)
class CutlassInt4PackedLayout(Layout):
    """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel.

    Tensor impls registered for this layout are ``Int4PackedTensorImpl``
    instances, which pack two int4 values into each element of their
    ``int_data``. The layout itself carries no configuration fields.
    """
@register_layout(CutlassInt4PackedLayout)
class Int4PackedTensorImpl(AQTTensorImpl):
    """
    TensorImpl storage class for int4 packed layout for affine quantized tensor.

    ``int_data`` holds two int4 values per element along its last dimension.
    The value at an even index of the plain (unpacked) tensor occupies the low
    nibble, and the following odd-index value occupies the high nibble (see
    ``from_plain`` and ``get_plain``). Only ``scale`` is stored: ``get_plain``
    reports no zero point, and ``from_plain`` accepts only an absent or
    all-zero one.

    The wrapper tensor reports the *packed* shape and dtype of ``int_data``,
    so its last dimension is half that of the plain data.
    """

    @staticmethod
    def __new__(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        _layout: Layout,
    ):
        # The wrapper mirrors the packed storage tensor's metadata and never
        # requires grad.
        kwargs = {
            "device": int_data.device,
            "layout": int_data.layout,
            "dtype": int_data.dtype,
            "requires_grad": False,
        }
        shape = int_data.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        _layout: Layout,
    ):
        self.int_data = int_data
        self.scale = scale
        self._layout = _layout

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        """Handle the small set of aten ops this impl supports.

        Supported ops:
          * ``aten.detach.default``: applied to each inner tensor, producing a
            new impl.
          * ``aten.copy_.default``: copies each inner tensor in place, provided
            ``_same_metadata`` holds.

        Raises:
            ValueError: for ``copy_`` between impls whose metadata differs.
            NotImplementedError: for any other op.
        """
        kwargs = {} if kwargs is None else kwargs

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                # Copy every inner tensor named by __tensor_flatten__ in place.
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                # In-place ops conventionally return the mutated tensor.
                return self
            raise ValueError(
                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
            )

        raise NotImplementedError(
            f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
        )

    def __tensor_flatten__(self):
        """Return the inner tensor attribute names and the non-tensor metadata."""
        return ["int_data", "scale"], [self._layout]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        """Rebuild an impl from the pieces produced by ``__tensor_flatten__``."""
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        (_layout,) = tensor_attributes
        return cls(int_data, scale, _layout)

    def get_plain(self):
        """Unpack ``int_data`` so that each element holds one int4 value.

        Returns:
            ``(int_data, scale, None)``. The unpacked ``int_data`` has its last
            dimension doubled. ``None`` stands for the absent zero point.
        """
        # For signed storage (int8 per _aqt_is_int4), shifting left then
        # arithmetically right sign-extends the low nibble, and a plain
        # arithmetic right shift sign-extends the high nibble.
        low = (self.int_data << 4) >> 4
        high = self.int_data >> 4
        # Interleave so low nibbles land on even indices and high nibbles on
        # odd ones, mirroring from_plain.
        int_data = torch.stack((low, high), dim=-1).view(
            self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],)
        )
        return int_data, self.scale, None

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        """Pack plain int4 values (one per element) two per element.

        Args:
            int_data: plain int4 values; the last dimension must be even.
            scale: quantization scale, stored unchanged.
            zero_point: must be ``None`` or all zeros (symmetric quantization).
            _layout: layout object stored on the impl.

        Raises:
            ValueError: if the last dimension of ``int_data`` is odd.
        """
        assert zero_point is None or torch.all(
            zero_point == 0
        ), "Int4PackedTensorImpl only supports a zero (or absent) zero_point"
        # An odd last dimension would make the odd/even slices differ in size.
        # Broadcasting then silently yields wrong or empty packed data.
        if int_data.shape[-1] % 2 != 0:
            raise ValueError(
                "Int4PackedTensorImpl.from_plain requires an even last dimension, "
                f"got shape {tuple(int_data.shape)}"
            )
        # Odd-index values go into the high nibble, even-index values into the low.
        int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
        return cls(
            int_data_s4,
            scale,
            _layout,
        )

    def get_layout(self) -> Layout:
        """Return the layout object this impl was built with."""
        return self._layout

    def _apply_fn_to_data(self, fn):
        """Return a new impl with ``fn`` applied to ``int_data`` and ``scale``.

        The receiver is left untouched, so non-mutating ops such as ``detach``
        do not rewrite the source tensor's data.
        """
        return self.__class__(fn(self.int_data), fn(self.scale), self._layout)
def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
    """Decide whether the CUTLASS int8-activation x int4-weight kernel applies.

    Returns True only when all of the following hold:
      * ``input_tensor`` is an ``AffineQuantizedTensor`` with ``PlainLayout``
        and int8 quantization. It has a float16/bfloat16 logical dtype, at
        least 2 dims, and a float32 scale with one dimension fewer than itself.
      * ``weight_tensor`` is a 2-D ``AffineQuantizedTensor`` with
        ``CutlassInt4PackedLayout`` and int4 quantization. Its logical dtype
        matches the input's, and it has a 1-D float32 scale.
      * ``bias`` is None, or a 1-D tensor with the input's dtype.
    """
    # Activation: the isinstance guards must run before any attribute access.
    if not (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, PlainLayout)
        and _aqt_is_int8(input_tensor)
    ):
        return False
    if input_tensor.dtype not in (torch.float16, torch.bfloat16):
        return False
    if len(input_tensor.shape) < 2:
        return False
    act_scale = input_tensor.tensor_impl.scale
    if act_scale.dtype != torch.float32:
        return False
    if len(act_scale.shape) != len(input_tensor.shape) - 1:
        return False

    # Weight.
    if not (
        isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(weight_tensor)
    ):
        return False
    if weight_tensor.dtype != input_tensor.dtype or len(weight_tensor.shape) != 2:
        return False
    wgt_scale = weight_tensor.tensor_impl.scale
    if wgt_scale.dtype != torch.float32 or len(wgt_scale.shape) != 1:
        return False

    # Bias is optional; when present it must be a vector in the input dtype.
    return bias is None or (bias.dtype == input_tensor.dtype and len(bias.shape) == 1)
def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
    """Run the CUTLASS int8-activation x int4-weight row-wise scaled linear.

    Forwards four values to ``torchao.ops.rowwise_scaled_linear_cutlass_s8s4``:
    the quantized data and scale of the activation, then those of the weight.
    It also passes ``bias`` and requests output in the activation's logical dtype.
    """
    # Imported lazily inside the function (presumably so the kernel extension is
    # only loaded when this path is taken).
    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4

    act_impl = input_tensor.tensor_impl
    wgt_impl = weight_tensor.tensor_impl
    return rowwise_scaled_linear_cutlass_s8s4(
        act_impl.int_data,
        act_impl.scale,
        wgt_impl.int_data,
        wgt_impl.scale,
        bias,
        input_tensor.dtype,
    )
def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
    """Decide whether the CUTLASS int4-activation x int4-weight kernel applies.

    Returns True only when all of the following hold:
      * ``input_tensor`` is an ``AffineQuantizedTensor`` with
        ``CutlassInt4PackedLayout`` and int4 quantization. It has a
        float16/bfloat16 logical dtype, at least 2 dims, and a float32 scale
        with one dimension fewer than itself.
      * ``weight_tensor`` is a 2-D ``AffineQuantizedTensor`` with
        ``CutlassInt4PackedLayout`` and int4 quantization. Its logical dtype
        matches the input's, and it has a 1-D float32 scale.
      * ``bias`` is None, or a 1-D tensor with the input's dtype. The int8
        activation check applies the same bias rule, so a mismatched bias is
        rejected here rather than being handed to the kernel.
    """
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(input_tensor)
        and input_tensor.dtype in (torch.float16, torch.bfloat16)
        and len(input_tensor.shape) >= 2
        and input_tensor.tensor_impl.scale.dtype == torch.float32
        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(weight_tensor)
        and weight_tensor.dtype == input_tensor.dtype
        and len(weight_tensor.shape) == 2
        and weight_tensor.tensor_impl.scale.dtype == torch.float32
        and len(weight_tensor.tensor_impl.scale.shape) == 1
        and (bias is None or bias.dtype == input_tensor.dtype)
        and (bias is None or len(bias.shape) == 1)
    )
def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
    """Run the CUTLASS int4-activation x int4-weight row-wise scaled linear.

    Forwards four values to ``torchao.ops.rowwise_scaled_linear_cutlass_s4s4``:
    the packed quantized data and scale of the activation, then those of the
    weight. It also passes ``bias`` and requests output in the activation's
    logical dtype.
    """
    # Imported lazily inside the function (presumably so the kernel extension is
    # only loaded when this path is taken).
    from torchao.ops import rowwise_scaled_linear_cutlass_s4s4

    act_impl = input_tensor.tensor_impl
    wgt_impl = weight_tensor.tensor_impl
    return rowwise_scaled_linear_cutlass_s4s4(
        act_impl.int_data,
        act_impl.scale,
        wgt_impl.int_data,
        wgt_impl.scale,
        bias,
        input_tensor.dtype,
    )