Source code for torchao.dtypes.uintx.tensor_core_tiled_layout

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, is_device
from torchao.quantization.quant_primitives import ZeroPointDomain, _get_reduction_params
from torchao.utils import (
    TORCH_VERSION_AT_LEAST_2_5,
    fill_defaults,
    find_multiple,
)

aten = torch.ops.aten


def _aqt_is_tensor_core_tile_uint4(aqt):
    """Check if an AffineQuantizedTensor is uint4 quantized Tensor"""
    # TODO: use torch.uint4
    return (
        aqt.tensor_impl.dtype == torch.int32
        and aqt.quant_min == 0
        and aqt.quant_max == 15
    )


def _same_metadata(
    self: "TensorCoreTiledAQTTensorImpl", src: "TensorCoreTiledAQTTensorImpl"
) -> bool:
    return (
        isinstance(self, TensorCoreTiledAQTTensorImpl)
        and isinstance(src, TensorCoreTiledAQTTensorImpl)
        and self.shape == src.shape
        and self.packed_weight.shape == src.packed_weight.shape
        and self.scale_and_zero.shape == src.scale_and_zero.shape
        and self.transposed == src.transposed
        and type(self._layout) == type(src._layout)
    )


def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias):
    return (
        # input is native bfloat16 tensor
        not is_traceable_wrapper_subclass(input_tensor)
        and input_tensor.dtype == torch.bfloat16
        and
        # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor
        isinstance(weight_tensor, AffineQuantizedTensor)
        and _aqt_is_tensor_core_tile_uint4(weight_tensor)
        and weight_tensor.dtype == torch.bfloat16
        and len(weight_tensor.shape) == 2
        and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT
        and isinstance(weight_tensor._layout, TensorCoreTiledLayout)
    )
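

# Illustrative sketch only (`_example_int4_weight_only_linear` is not part of this module):
# it shows one way to produce a weight that satisfies the check above, assuming a CUDA
# device and the `int4_weight_only` config exported by `torchao.quantization`; the exact
# import surface may differ across torchao versions.
def _example_int4_weight_only_linear():
    from torchao.quantization import int4_weight_only, quantize_

    linear = torch.nn.Linear(
        1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda"
    )
    # int4_weight_only uses TensorCoreTiledLayout by default, so after quantize_ the
    # weight is an AffineQuantizedTensor backed by TensorCoreTiledAQTTensorImpl
    quantize_(linear, int4_weight_only(group_size=128))
    activation = torch.randn(8, 1024, dtype=torch.bfloat16, device="cuda")
    assert _linear_bf16_act_uint4_weight_check(activation, linear.weight, None)
    return linear(activation)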


def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias):
    assert weight_tensor.block_size[0] == 1, (
        f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}"
    )
    assert input_tensor.shape[-1] == weight_tensor.shape[1], (
        f"need input_tensor shape: {input_tensor.shape} final "
        f"dim to match weight_tensor shape: {weight_tensor.shape} second dim "
    )

    # TODO: check groupsize quantization
    # avoid circular dep, TODO: move this to a common util.py
    act_mat = input_tensor
    # weight is packed from padded (out_features, in_features) weight tensor
    # (same dimension requirement as F.linear weight)
    packed_weight = weight_tensor.tensor_impl.packed_weight
    scale_and_zero = weight_tensor.tensor_impl.scale_and_zero

    orig_act_size = act_mat.size()
    orig_dtype = act_mat.dtype

    # reshape and pad activation
    act_mat = act_mat.reshape(-1, act_mat.shape[-1]).to(torch.bfloat16)
    pad_size = find_multiple(act_mat.shape[-1], 1024)
    act_mat = torch.nn.functional.pad(act_mat, (0, pad_size - act_mat.shape[-1]))

    # groupwise int4 quantization
    groupsize = weight_tensor.block_size[-1]
    if act_mat.numel() == 0:  # handling for empty input
        y = act_mat
    else:
        y = torch.ops.aten._weight_int4pack_mm(
            act_mat.contiguous(), packed_weight, groupsize, scale_and_zero
        )
    # remove out_feature padding
    orig_out_features = weight_tensor.shape[-2]
    y = y[:, :orig_out_features]
    y = y.reshape(*orig_act_size[:-1], orig_out_features)

    if bias is not None:
        y += bias
    return y.to(orig_dtype)
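

# Illustrative sketch only (`_example_activation_padding` is not part of this module):
# it demonstrates the activation padding performed in `_linear_bf16_act_uint4_weight_impl`,
# where the last dimension is rounded up to a multiple of 1024 with `find_multiple` so it
# matches the in_features padding applied when the weight was packed.
def _example_activation_padding():
    act = torch.randn(3, 4000, dtype=torch.bfloat16)
    padded_k = find_multiple(act.shape[-1], 1024)  # 4000 -> 4096
    act = torch.nn.functional.pad(act, (0, padded_k - act.shape[-1]))
    return act.shape  # torch.Size([3, 4096])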


@dataclass(frozen=True)
class TensorCoreTiledLayout(Layout):
    """TensorCoreTiledLayout is a layout class for handling tensor core tiled layouts in
    affine quantized tensors. It provides methods for pre-processing and post-processing
    tensors to fit the required layout for efficient computation on tensor cores.

    Attributes:
        inner_k_tiles (int): An internal argument for the packing function of tensor core tiled layout
          that can affect the performance of the matmul kernel. Defaults to 8.
    """

    inner_k_tiles: int = 8

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        orig_out_features, orig_in_features = input.shape[-2:]
        in_features = find_multiple(orig_in_features, 1024)
        out_features = find_multiple(orig_out_features, 8)
        input = torch.nn.functional.pad(
            input,
            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
        )
        return input

    def pre_process_static(
        self,
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        block_size: Tuple[int, ...],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        input = self.pre_process(input)
        orig_qparam_shape = scale.shape
        new_qparam_shape, reduction_dims = _get_reduction_params(
            block_size, input.size()
        )
        for dim in reduction_dims:
            new_qparam_shape.pop(dim)
        change_in_qparam_shape = [
            new_dim_size - orig_dim_size
            for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)
        ]
        padding_changes = []
        for dim_change in change_in_qparam_shape:
            padding_changes = [0, dim_change] + padding_changes
        scale = torch.nn.functional.pad(scale, padding_changes)
        zero_point = torch.nn.functional.pad(zero_point, padding_changes)
        return input, scale, zero_point

    def post_process(
        self,
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        block_size: Tuple[int, ...],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        orig_out_features, orig_in_features = input.shape[-2:]
        in_features = find_multiple(orig_in_features, 1024)
        out_features = find_multiple(orig_out_features, 8)
        input = torch.nn.functional.pad(
            input,
            (0, in_features - orig_in_features, 0, out_features - orig_out_features),
        )
        assert len(block_size) == 2 or len(block_size) == 3, (
            f"TensorCoreTiledLayout only supports len(block_size) == 2 or 3, got: {block_size}"
        )
        scale_pad_dim_0 = (out_features - orig_out_features) // block_size[-2]
        scale_pad_dim_1 = (in_features - orig_in_features) // block_size[-1]
        scale = torch.nn.functional.pad(scale, (0, scale_pad_dim_1, 0, scale_pad_dim_0))
        zero_point = torch.nn.functional.pad(
            zero_point, (0, scale_pad_dim_1, 0, scale_pad_dim_0)
        )
        return input, scale, zero_point

    def extra_repr(self):
        return f"inner_k_tiles={self.inner_k_tiles}"
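

# Illustrative sketch only (`_example_layout_pre_process` is not part of this module):
# it shows what `TensorCoreTiledLayout.pre_process` does to a weight whose dimensions are
# not already tile-aligned: out_features is padded up to a multiple of 8 and in_features
# up to a multiple of 1024.
def _example_layout_pre_process():
    layout = TensorCoreTiledLayout(inner_k_tiles=8)
    weight = torch.randn(10, 1000, dtype=torch.bfloat16)  # (out_features, in_features)
    padded = layout.pre_process(weight)
    return padded.shape  # torch.Size([16, 1024])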


@register_layout(TensorCoreTiledLayout)
class TensorCoreTiledAQTTensorImpl(AQTTensorImpl):
    """TensorImpl for tensor_core_tiled layout for affine quantized tensor, this is for int4 only,
    used by tinygemm kernels `_weight_int4pack_mm`

    It stores the original tensor of dimension [n][k] (int32 dtype) as packed weight of 4-d tensor of
    dimension: [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2]
    (unpacked Tensor shape is n * k)

    where inner_k_tiles is an internal argument for packing function of tensor core tiled layout
    that can affect the performance of the matmul kernel (defaults to 8)

    Note: we also pack scale and zero point together here for tinygemm kernel

    Note: technically tensor core tiled layout should be the layout for the underlying packed weight
    (int Tensor) but since the scale and zero_point are also packed into the same tensor here which
    is not used in plain layout, we just created a layout for AQT right now, this could be improved
    if we split out int4 aqt into a separate tensor subclass

    fields:
      packed_weight (torch.Tensor): the 4-d packed tensor in a tensor_core_tiled layout
      scale_and_zero (torch.Tensor): the combined scale Tensor used to map between floating point
        tensor and quantized tensor, packed together with the zero_point Tensor
    """

    def __new__(
        cls,
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
        kwargs["layout"] = (
            kwargs.get("layout")
            if kwargs.get("layout", False)
            else packed_weight.layout
        )
        kwargs["dtype"] = packed_weight.dtype
        kwargs["requires_grad"] = False
        shape = packed_weight.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero
        self.transposed = False
        self._layout = _layout

    def __tensor_flatten__(self):
        return ["packed_weight", "scale_and_zero"], [self.transposed, self._layout]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale_and_zero = (
            tensor_data_dict["packed_weight"],
            tensor_data_dict["scale_and_zero"],
        )
        (
            transposed,
            _layout,
        ) = tensor_attributes
        return cls(packed_weight, scale_and_zero, transposed, _layout)

    @classmethod
    def from_plain(
        cls,
        int_data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        assert isinstance(_layout, TensorCoreTiledLayout)
        assert int_data.dtype == torch.int32, (
            "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
        )

        def quant_2d(int_data_2d):
            if TORCH_VERSION_AT_LEAST_2_5:
                int_data_2d = (int_data_2d[::, ::2] << 4 | int_data_2d[::, 1::2]).to(
                    torch.uint8
                )
            else:
                assert int_data_2d.dtype == torch.int32, (
                    "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype"
                )
            return torch.ops.aten._convert_weight_to_int4pack(
                int_data_2d.contiguous(), _layout.inner_k_tiles
            )

        if int_data.dim() == 3:  # for moe quant
            num_experts = int_data.shape[0]
            packed_weight_list = []
            for expert in range(num_experts):
                packed_weight_list.append(quant_2d(int_data[expert]).unsqueeze(0))
            packed_weight = torch.cat(packed_weight_list, dim=0)
            scale = scale.reshape(int_data.shape[0], int_data.shape[-2], -1)
            zero_point = (
                zero_point.reshape(int_data.shape[0], int_data.shape[-2], -1)
                if zero_point is not None
                else None
            )
        else:
            assert int_data.dim() == 2
            packed_weight = quant_2d(int_data)
            scale = scale.reshape(int_data.shape[0], -1)
            zero_point = (
                zero_point.reshape(int_data.shape[0], -1)
                if zero_point is not None
                else None
            )
        from torchao.quantization.utils import pack_tinygemm_scales_and_zeros

        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point, scale.dtype)
        return cls(packed_weight, scale_and_zero, False, _layout)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        device = kwargs["device"]
        # tensor core tiled layout supports both cpu and cuda but does not support the conversion
        # between these two devices, in the future we should not use the same layout for
        # cpu and cuda device: https://github.com/pytorch/ao/issues/1117
        if not is_device(torch.device(self.device).type, device):
            raise ValueError(
                f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}"
            )
        return self.__class__(
            self.packed_weight.to(device),
            self.scale_and_zero.to(device),
            self.transposed,
            self._layout,
        )

    def _apply_fn_to_data(self, fn):
        # self.packed_weight = fn(self.packed_weight)
        # self.scale_and_zero = fn(self.scale_and_zero)
        # return self
        return self.__class__(
            fn(self.packed_weight),
            fn(self.scale_and_zero),
            self.transposed,
            self._layout,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        if func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                return
            raise ValueError(
                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
            )

        if func in [aten.select.int, aten.index.Tensor]:
            assert not (func is aten.select.int and args[1] != 0), (
                "aten.select.int currently only has support for dim=0"
            )
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
            )

        if func is aten.t.default:
            """we don't need to repack the weight and just rely on external
            shape being changed and record the status of transpose/no-transpose
            """
            transposed = TensorCoreTiledAQTTensorImpl(
                args[0].packed_weight,
                args[0].scale_and_zero,
                not args[0].transposed,
                args[0]._layout,
            )
            return return_and_correct_aliasing(func, args, kwargs, transposed)

        if func is aten.slice.Tensor:
            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
            cur_shape = self.shape
            assert len(cur_shape) == 4
            inner_k_tiles = cur_shape[-1] * 2
            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
            n_by_8, k_by_inner_tiles, _, _ = self.packed_weight.shape
            sz_dim1, sz_dim0, _ = self.scale_and_zero.shape
            data_len = original_shape[dim]
            assert dim in [0, 1], (
                f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
            )
            if dim == 0:
                pw_len = n_by_8
                sz_len = sz_dim0
            else:
                pw_len = k_by_inner_tiles
                sz_len = sz_dim1

            if pw_len == 0 or sz_len == 0:
                return return_and_correct_aliasing(
                    func,
                    args,
                    kwargs,
                    TensorCoreTiledAQTTensorImpl(
                        self.packed_weight,
                        self.scale_and_zero,
                        self.transposed,
                        self._layout,
                    ),
                )

            pw_ratio = data_len / pw_len
            start_pw = int(start / pw_ratio)
            end_pw = int(end / pw_ratio)

            sz_ratio = data_len / sz_len
            start_sz = int(start / sz_ratio)
            end_sz = int(end / sz_ratio)

            packed_weight = aten.slice(self.packed_weight, dim, start_pw, end_pw, step)
            scale_and_zero = aten.slice(
                self.scale_and_zero, 1 - dim, start_sz, end_sz, step
            )
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                TensorCoreTiledAQTTensorImpl(
                    packed_weight, scale_and_zero, self.transposed, self._layout
                ),
            )

        raise NotImplementedError(
            f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
        )

    __torch_function__ = torch._C._disabled_torch_function_impl

    @property
    def block_size(self):
        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
        cur_shape = self.shape
        if len(cur_shape) == 5:
            ones = [1, 1]
            cur_shape = cur_shape[1:]
        else:
            assert len(cur_shape) == 4
            ones = [1]
        inner_k_tiles = cur_shape[-1] * 2
        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
        groupsize = int(original_shape[1] / scale.shape[-2])
        return tuple([*ones, groupsize])

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        from torchao.quantization.quant_primitives import (
            ZeroPointDomain,
            quantize_affine,
        )
        from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros

        def dequant_4d(self):
            cur_shape = self.shape
            scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
            assert len(cur_shape) == 4
            inner_k_tiles = cur_shape[-1] * 2
            original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
            eye_shape = original_shape[1]
            groupsize = int(original_shape[1] / scale.shape[-2])
            block_size = (1, groupsize)
            original_dtype = torch.bfloat16
            assert len(block_size) == 2 and block_size[0] == 1
            dequantized = torch.ops.aten._weight_int4pack_mm(
                torch.eye(eye_shape, device=self.device, dtype=original_dtype),
                self.packed_weight,
                groupsize,
                self.scale_and_zero,
            )
            dequantized = dequantized.t().contiguous()
            return dequantized

        cur_shape = self.shape

        if len(cur_shape) == 4:
            dequantized = dequant_4d(self)
        else:
            assert len(cur_shape) == 5
            num_experts = cur_shape[0]
            dequantized_list = []
            for expert in range(num_experts):
                dequantized_list.append(dequant_4d(self[expert]).unsqueeze(0))
            dequantized = torch.cat(dequantized_list, dim=0)

        scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero)
        # TODO: move this to `unpack_tinygemm_scales_and_zeros`?
        scale = scale.reshape(scale.shape[:-1]).contiguous()
        zero = zero.reshape(zero.shape[:-1]).contiguous()

        target_dtype = torch.int32
        quant_min = 0
        quant_max = 15
        zero_point_domain = ZeroPointDomain.FLOAT
        int_data = quantize_affine(
            dequantized,
            self.block_size,
            scale,
            zero,
            target_dtype,
            quant_min,
            quant_max,
            zero_point_domain,
        )
        return int_data, scale, zero

    def get_layout(self) -> Layout:
        return self._layout
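

# Worked example only (`_example_packed_shape` is not part of this module): it spells out
# the packed shape described in the TensorCoreTiledAQTTensorImpl docstring. An [n][k] int4
# weight is stored as a 4-d tensor of shape [n / 8][k / (inner_k_tiles * 16)][32][inner_k_tiles / 2],
# i.e. n * k / 8 int32 elements, with eight 4-bit values packed into each int32.
def _example_packed_shape(n: int = 4096, k: int = 4096, inner_k_tiles: int = 8):
    packed_shape = (n // 8, k // (inner_k_tiles * 16), 32, inner_k_tiles // 2)
    n_packed_elems = (
        packed_shape[0] * packed_shape[1] * packed_shape[2] * packed_shape[3]
    )
    assert n_packed_elems * 8 == n * k  # 8 int4 values per packed int32 element
    return packed_shape  # (512, 32, 32, 4) for the defaults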
