Source code for torchao.dtypes.uintx.block_sparse_layout

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import logging
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
from torch.utils._python_dispatch import (
    return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.uintx.plain_layout import (
    PlainAQTTensorImpl,
    _aqt_is_int8_reduced_range,
)
from torchao.dtypes.utils import (
    Layout,
    PlainLayout,
)

logger = logging.getLogger(__name__)

aten = torch.ops.aten


@dataclass(frozen=True)
class BlockSparseLayout(Layout):
    """BlockSparseLayout is a data class that represents the layout of a block sparse matrix.

    Attributes:
        blocksize (int): The size of the blocks in the sparse matrix. Default is 64.
    """

    blocksize: int = 64
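
# Illustration (a sketch added for exposition, not part of the original
# module): ``blocksize`` controls the tile size of the BSR format, which
# stores a 2D matrix as dense ``blocksize x blocksize`` tiles and keeps only
# the tiles that contain nonzeros. Using only the public
# ``Tensor.to_sparse_bsr`` API, with expected results shown as comments:
#
#   dense = torch.eye(128)
#   bsr = dense.to_sparse_bsr(64)
#   bsr.values().shape   # torch.Size([2, 64, 64]): the two diagonal tiles
#   bsr.crow_indices()   # tensor([0, 1, 2]): row pointers over block rows
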
@register_layout(BlockSparseLayout)
class BlockSparseAQTTensorImpl(PlainAQTTensorImpl):
    bsr_crow_indices: Optional[torch.Tensor]
    bsr_col_indices: Optional[torch.Tensor]
    bsr_values: Optional[torch.Tensor]
    scale: Optional[torch.Tensor]
    zero_point: Optional[torch.Tensor]

    __slots__ = [
        "bsr_crow_indices",
        "bsr_col_indices",
        "bsr_values",
        "scale",
        "zero_point",
    ]

    @staticmethod
    def __new__(  # noqa: PYI034
        cls,
        shape: torch.Size,
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
        requires_grad: bool = False,
    ):
        if bsr_values is None:
            raise ValueError("bsr values must be provided!")
        else:
            previous_tensor = bsr_values

        # Mirror the metadata of the stored BSR values tensor on the wrapper.
        kwargs = {
            "device": previous_tensor.device,
            "dtype": previous_tensor.dtype,
            "layout": previous_tensor.layout,
            "requires_grad": requires_grad,
        }
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(  # noqa: PYI034
        self,
        shape: torch.Size,
        bsr_crow_indices: Optional[torch.Tensor],
        bsr_col_indices: Optional[torch.Tensor],
        bsr_values: Optional[torch.Tensor],
        scale: Optional[torch.Tensor],
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
        requires_grad: bool = False,
    ):
        self.bsr_crow_indices = bsr_crow_indices
        self.bsr_col_indices = bsr_col_indices
        self.bsr_values = bsr_values
        self.scale = scale
        self.zero_point = zero_point
        self._layout = _layout

    def __tensor_flatten__(self):
        inner_tensors = list(
            filter(lambda x: getattr(self, x) is not None, self.__slots__)
        )
        tensor_meta = (self.shape, self._layout, self.requires_grad)
        return inner_tensors, tensor_meta

    @classmethod
    def __tensor_unflatten__(
        cls,
        inner_tensors,
        tensor_meta: Tuple[torch.Size, Layout, bool],
        outer_size,
        outer_stride,
    ) -> torch.Tensor:
        shape, _layout, requires_grad = tensor_meta
        return cls(
            shape=shape,
            bsr_crow_indices=inner_tensors.get("bsr_crow_indices", None),
            bsr_col_indices=inner_tensors.get("bsr_col_indices", None),
            bsr_values=inner_tensors.get("bsr_values", None),
            scale=inner_tensors.get("scale", None),
            zero_point=inner_tensors.get("zero_point", None),
            _layout=_layout,
            requires_grad=requires_grad,
        )

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, _layout):
        bsr_tensor = int_data.to_sparse_bsr(_layout.blocksize)
        return cls(
            shape=int_data.shape,
            bsr_crow_indices=bsr_tensor.crow_indices(),
            bsr_col_indices=bsr_tensor.col_indices(),
            bsr_values=bsr_tensor.values(),
            scale=scale,
            zero_point=zero_point,
            _layout=_layout,
            requires_grad=False,
        )

    def get_plain(self):
        # Re-expand the BSR components to a dense int tensor via torchao's
        # custom blocksparse op.
        int_data_expanded = torch.ops.blocksparse.bsr_to_dense(
            self.crow_indices(),
            self.col_indices(),
            self.values(),
            self.shape[0],
            self.shape[1],
        )
        return int_data_expanded, self.scale, self.zero_point

    def _apply_fn_to_data(self, func):
        return self.__class__(
            shape=self.shape,
            bsr_crow_indices=func(self.bsr_crow_indices),
            bsr_col_indices=func(self.bsr_col_indices),
            bsr_values=func(self.bsr_values),
            scale=self.scale,
            zero_point=self.zero_point,
            _layout=self._layout,
            requires_grad=self.requires_grad,
        )

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        if func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )

        # Need the following for bsr specific functions
        if func is aten.crow_indices.default:
            return args[0].bsr_crow_indices.detach()

        if func is aten.col_indices.default:
            return args[0].bsr_col_indices.detach()

        if func is aten.values.default:
            return args[0].bsr_values.detach()

        if func is aten._nnz.default:
            return args[0].bsr_values.shape[0]

        raise NotImplementedError(
            f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported"
        )


def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias):
    return (
        isinstance(input_tensor, AffineQuantizedTensor)
        and _aqt_is_int8_reduced_range(input_tensor)
        and isinstance(weight_tensor, AffineQuantizedTensor)
        and weight_tensor.is_cuda
        and input_tensor.dtype == weight_tensor.dtype
        and isinstance(input_tensor._layout, PlainLayout)
        and isinstance(weight_tensor._layout, BlockSparseLayout)
    )


def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, bias):
    x_vals_int8 = input_tensor.tensor_impl.int_data
    x_scales = input_tensor.tensor_impl.scale
    w_vals = weight_tensor.tensor_impl
    w_scales = weight_tensor.tensor_impl.scale
    tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
    tmp_t = tmp.t()

    y = torch.ops.blocksparse.int_addmm(
        w_vals.crow_indices(),
        w_vals.col_indices(),
        w_vals.values(),
        tmp_t,
        w_scales,
        x_scales.reshape(-1),
    )
    y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1])
    y = y.reshape(*y_shape)

    # can downcast only at the very end
    output_dtype = input_tensor.dtype
    y = y.to(output_dtype)
    if bias is not None:
        y += bias
    return y
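
A minimal usage sketch (not part of the module above): it builds a BlockSparseAQTTensorImpl directly with from_plain and reads the BSR components back through __torch_dispatch__. The shapes and the per-row scale/zero_point are illustrative assumptions; the fused torch.ops.blocksparse.int_addmm path additionally requires torchao's custom ops and a CUDA weight, which this sketch does not exercise.

import torch

from torchao.dtypes.uintx.block_sparse_layout import (
    BlockSparseAQTTensorImpl,
    BlockSparseLayout,
)

layout = BlockSparseLayout(blocksize=64)

# A 128x128 int8 "weight" with a single nonzero 64x64 tile.
int_data = torch.zeros(128, 128, dtype=torch.int8)
int_data[:64, :64] = 1
scale = torch.ones(128)       # illustrative per-row scales
zero_point = torch.zeros(128)

impl = BlockSparseAQTTensorImpl.from_plain(int_data, scale, zero_point, layout)

# The BSR components are surfaced through __torch_dispatch__ above.
print(impl.crow_indices())  # row pointers over block rows
print(impl.col_indices())   # block-column index of each stored tile
print(impl.values().shape)  # torch.Size([1, 64, 64]): one retained tile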
