Source code for torchao.dtypes.uintx.cutlass_int4_packed_layout

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
from torch.utils._python_dispatch import (
    return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
    _aqt_is_int8,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout

aten = torch.ops.aten


def _aqt_is_int4(aqt):
    """Check if an AffineQuantizedTensor is int4 quantized Tensor"""
    # TODO: use torch.int4
    return (
        aqt.tensor_impl.dtype == torch.int8
        and aqt.quant_min == -8
        and aqt.quant_max == 7
    )


def _same_metadata(self: "Int4PackedTensorImpl", src: "Int4PackedTensorImpl") -> bool:
    return (
        isinstance(self, Int4PackedTensorImpl)
        and isinstance(src, Int4PackedTensorImpl)
        and self.shape == src.shape
        and self.int_data.shape == src.int_data.shape
        and self.scale.shape == src.scale.shape
        and type(self._layout) == type(src._layout)
    )


@dataclass(frozen=True)
class CutlassInt4PackedLayout(Layout):
    """Layout class for int4 packed layout for affine quantized tensor, for cutlass kernel."""

    pass


@register_layout(CutlassInt4PackedLayout)
class Int4PackedTensorImpl(AQTTensorImpl):
    """
    TensorImpl storage class for int4 packed layout for affine quantized tensor.
    """

    @staticmethod
    def __new__(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        _layout: Layout,
    ):
        kwargs = {}
        kwargs["device"] = int_data.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
        )
        kwargs["dtype"] = int_data.dtype
        kwargs["requires_grad"] = False
        shape = int_data.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        _layout: Layout,
    ):
        self.int_data = int_data
        self.scale = scale
        self._layout = _layout

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        elif func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                return
            raise ValueError(
                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
            )

        raise NotImplementedError(
            f"Int4PackedTensorImpl dispatch: attempting to run {func}, this is not supported"
        )

    def __tensor_flatten__(self):
        return ["int_data", "scale"], [self._layout]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        int_data = tensor_data_dict["int_data"]
        scale = tensor_data_dict["scale"]
        (_layout,) = tensor_attributes
        return cls(int_data, scale, _layout)

    def get_plain(self):
        # Unpack two int4 values from each int8 byte; the arithmetic shifts
        # sign-extend the low and high nibbles back to signed int8 values.
        int_data = torch.stack(
            ((self.int_data << 4) >> 4, self.int_data >> 4), dim=-1
        ).view(self.int_data.shape[:-1] + (2 * self.int_data.shape[-1],))
        return int_data, self.scale, None

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        assert zero_point is None or torch.all(zero_point == 0)
        # Pack pairs of int4 values into single int8 bytes: even-indexed values
        # go into the low nibble, odd-indexed values into the high nibble.
        int_data_s4 = ((int_data[..., 1::2] & 0xF) << 4) | (int_data[..., 0::2] & 0xF)
        return cls(
            int_data_s4,
            scale,
            _layout,
        )

    def get_layout(self) -> Layout:
        return self._layout

    def _apply_fn_to_data(self, fn):
        self.int_data = fn(self.int_data)
        self.scale = fn(self.scale)
        return self


def _linear_int8_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, PlainLayout)
        and _aqt_is_int8(input_tensor)
        and input_tensor.dtype in (torch.float16, torch.bfloat16)
        and len(input_tensor.shape) >= 2
        and input_tensor.tensor_impl.scale.dtype == torch.float32
        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(weight_tensor)
        and weight_tensor.dtype == input_tensor.dtype
        and len(weight_tensor.shape) == 2
        and weight_tensor.tensor_impl.scale.dtype == torch.float32
        and len(weight_tensor.tensor_impl.scale.shape) == 1
        and (bias is None or bias.dtype == input_tensor.dtype)
        and (bias is None or len(bias.shape) == 1)
    )


def _linear_int8_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
    from torchao.ops import rowwise_scaled_linear_cutlass_s8s4

    weight = weight_tensor.tensor_impl.int_data
    weight_scale = weight_tensor.tensor_impl.scale
    input = input_tensor.tensor_impl.int_data
    input_scale = input_tensor.tensor_impl.scale
    out_dtype = input_tensor.dtype

    out = rowwise_scaled_linear_cutlass_s8s4(
        input, input_scale, weight, weight_scale, bias, out_dtype
    )

    return out


def _linear_int4_act_int4_weight_cutlass_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and isinstance(input_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(input_tensor)
        and input_tensor.dtype in (torch.float16, torch.bfloat16)
        and len(input_tensor.shape) >= 2
        and input_tensor.tensor_impl.scale.dtype == torch.float32
        and len(input_tensor.tensor_impl.scale.shape) == len(input_tensor.shape) - 1
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, CutlassInt4PackedLayout)
        and _aqt_is_int4(weight_tensor)
        and weight_tensor.dtype == input_tensor.dtype
        and len(weight_tensor.shape) == 2
        and weight_tensor.tensor_impl.scale.dtype == torch.float32
        and len(weight_tensor.tensor_impl.scale.shape) == 1
    )


def _linear_int4_act_int4_weight_cutlass_impl(input_tensor, weight_tensor, bias):
    from torchao.ops import rowwise_scaled_linear_cutlass_s4s4

    weight = weight_tensor.tensor_impl.int_data
    weight_scale = weight_tensor.tensor_impl.scale
    input = input_tensor.tensor_impl.int_data
    input_scale = input_tensor.tensor_impl.scale
    out_dtype = input_tensor.dtype

    out = rowwise_scaled_linear_cutlass_s4s4(
        input, input_scale, weight, weight_scale, bias, out_dtype
    )

    return out
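
The from_plain / get_plain pair above is the core of the layout: two signed int4 values are stored per int8 byte, and arithmetic shifts recover them with correct signs. A minimal standalone sketch of that round trip, using only plain PyTorch and hypothetical variable names (not part of torchao):

    import torch

    # Values in the int4 range [-8, 7]; the last dimension must be even so
    # values can be packed in pairs.
    plain = torch.randint(-8, 8, (4, 16), dtype=torch.int8)

    # Pack: even-indexed values go into the low nibble, odd-indexed values
    # into the high nibble of each int8 byte (mirrors from_plain above).
    packed = ((plain[..., 1::2] & 0xF) << 4) | (plain[..., 0::2] & 0xF)

    # Unpack: arithmetic right shifts sign-extend each nibble back to int8
    # (mirrors get_plain above).
    unpacked = torch.stack(((packed << 4) >> 4, packed >> 4), dim=-1).view(plain.shape)

    assert torch.equal(unpacked, plain)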

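In practice this layout is selected through torchao's quantization API rather than constructed directly. A hedged sketch, assuming a torchao version in which int8_dynamic_activation_int4_weight accepts a layout argument (the exact signature and required companion arguments may differ across versions):

    import torch
    from torchao.dtypes import CutlassInt4PackedLayout
    from torchao.quantization import int8_dynamic_activation_int4_weight, quantize_

    # Hypothetical usage sketch: int8 dynamic activations with int4 weights,
    # requesting the CUTLASS packed layout. The `layout` and `group_size`
    # arguments here are assumptions, not a confirmed API contract.
    model = torch.nn.Sequential(torch.nn.Linear(256, 512)).half().cuda()
    quantize_(
        model,
        int8_dynamic_activation_int4_weight(
            group_size=None,
            layout=CutlassInt4PackedLayout(),
        ),
    )

Whether the CUTLASS kernel is actually dispatched then depends on the checks above (_linear_int8_act_int4_weight_cutlass_check), e.g. float16/bfloat16 inputs, float32 row-wise scales, and a 2-D int4 weight.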