
Source code for torchao.dtypes.utils

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch

from torchao.utils import TorchAOBaseTensor

"""
Base class for different layout, following the same design of PyTorch layout
https://pytorch.org/docs/stable/tensor_attributes.html#torch-layout, used to represent different
data layout of a Tensor, it's used in conjunction with TensorImpl to represent custom data layout.

As a native PyTorch example, Sparse Coordinate format Tensor (https://pytorch.org/docs/stable/generated/torch.sparse_coo_tensor.html#torch-sparse-coo-tensor) has `torch.sparse_coo` layout, which is backed up by
`SparseImpl`: https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/SparseTensorImpl.h which stores two Tensors (indices_ and values_)

We extended the layout in torchao with Layout class (instead of torch.layout objects), also we use tensor subclass to implement TensorImpl classes.

Layout also allows users to pass around configurations for the TensorImpl,
e.g. inner_k_tiles for int4 tensor core tiled TensorImpl

Note: Layout is an abstraction not only for custom data representation, it is also used for how the
Tensor interacts with different operators, e.g. the same data representation can have different
behaviors when running the same operator, e.g. transpose, quantized_linear. This is the same as layout
in PyTorch native Tensor
"""


@dataclass(frozen=True)
class Layout:
    """The Layout class serves as a base class for defining different data layouts for tensors.
    It provides methods for pre-processing and post-processing tensors, as well as static
    pre-processing with additional parameters like scale, zero_point, and block_size.

    The Layout class is designed to be extended by other layout classes that define specific
    data representations and behaviors for tensors. It is used in conjunction with TensorImpl
    classes to represent custom data layouts and how tensors interact with different operators.
    """

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        return input

    def post_process(
        self,
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        block_size: Tuple[int, ...],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return input, scale, zero_point

    def pre_process_static(
        self,
        input: torch.Tensor,
        scale: torch.Tensor,
        zero_point: torch.Tensor,
        block_size: Tuple[int, ...],
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        return self.pre_process(input), scale, zero_point

    def __repr__(self):
        return f"{self.__class__.__name__}({self.extra_repr()})"

    def extra_repr(self) -> str:
        return ""
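
# ---------------------------------------------------------------------------
# Illustrative example (not part of torchao.dtypes.utils): a minimal sketch of
# how a Layout subclass might carry configuration for its TensorImpl, assuming
# a hypothetical tile-based packing scheme. The name `MyTiledLayout` and its
# `tile_size` field are made up for illustration; see torchao's int4 tensor
# core tiled layout for a real instance of this pattern.
# ---------------------------------------------------------------------------
# @dataclass(frozen=True)
# class MyTiledLayout(Layout):
#     tile_size: int = 8
#
#     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
#         # Pad the last dimension to a multiple of tile_size so the TensorImpl
#         # can later pack the data into tiles without a ragged tail.
#         pad = (-input.shape[-1]) % self.tile_size
#         return torch.nn.functional.pad(input, (0, pad))
#
#     def extra_repr(self) -> str:
#         return f"tile_size={self.tile_size}"
#
# e.g. MyTiledLayout(tile_size=4).pre_process(torch.randn(3, 10)).shape == (3, 12)
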
@dataclass(frozen=True)
class PlainLayout(Layout):
    """PlainLayout is the most basic layout class, inheriting from the Layout base class.
    It does not add any additional metadata or processing steps to the tensor.
    Typically, this layout is used as the default when no specific layout is required.
    """

    pass
def is_device(target_device_str: str, device: Union[str, torch.device]):
    return torch.device(device).type == target_device_str


def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]:
    """Returns the output shape of a linear op with the given input and weight shapes:
    the input's leading dimensions followed by the weight's output dimension.

    Args:
        input_shape: The input tensor shape, possibly with more than 2 dimensions
        weight_shape: The weight tensor shape.

    Returns:
        The output shape of the linear op.
    """
    out_dim = weight_shape[0]
    inpt_dims = input_shape[:-1]
    return (*inpt_dims, out_dim)


###############################
# Base Tensor Impl Subclass   #
###############################
class AQTTensorImpl(TorchAOBaseTensor):
    """
    Base class for the tensor impl for `AffineQuantizedTensor`

    Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct
    the underlying implementation of an AQT based on layout
    """

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Get the plain (unpacked) Tensor for the tensor impl

        Returns data, scale and zero_point
        Can be overridden if other types of AQTTensorImpl have different numbers of plain tensors
        """
        pass

    def get_layout(self) -> Layout:
        pass

    @classmethod
    def from_plain(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        """Construct a TensorImpl from data, scale, zero_point and the _layout"""
        pass

    def __repr__(self):
        data, scale, zero_point = self.get_plain()
        _layout = self.get_layout()
        return f"{self.__class__.__name__}(data={str(data)}... , scale={str(scale)}... , zero_point={str(zero_point)}... , _layout={_layout})"
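
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of torchao.dtypes.utils): get_out_shape
# computes the output shape of a linear op, i.e. the input's leading dimensions
# followed by the weight's output dimension, and is_device compares device
# types after normalizing str / torch.device inputs.
# ---------------------------------------------------------------------------
# assert get_out_shape((2, 17, 256), (512, 256)) == (2, 17, 512)
# assert is_device("cpu", torch.device("cpu"))
# assert not is_device("cuda", "cpu")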
