
Source code for torchao.dtypes.floatx.float8_layout

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
from torch.utils._python_dispatch import (
    is_traceable_wrapper_subclass,
    return_and_correct_aliasing,
)

from torchao.dtypes.affine_quantized_tensor import (
    AffineQuantizedTensor,
    register_layout,
)
from torchao.dtypes.utils import AQTTensorImpl, Layout, get_out_shape
from torchao.float8.inference import (
    Float8MMConfig,
    _is_rowwise_scaled,
    addmm_float8_unwrapped_inference,
    preprocess_data,
)
from torchao.utils import _is_float8_type, fill_defaults

aten = torch.ops.aten


def _same_metadata(self: "Float8AQTTensorImpl", src: "Float8AQTTensorImpl") -> bool:
    # Special handling for transposed attribute
    transposed_match = (self.transposed == src.transposed) or (
        self.transposed is False and src.transposed is None
    )

    return (
        isinstance(self, Float8AQTTensorImpl)
        and isinstance(src, Float8AQTTensorImpl)
        and self.shape == src.shape
        and self.float8_data.shape == src.float8_data.shape
        and self.scale.shape == src.scale.shape
        and transposed_match
        and type(self._layout) == type(src._layout)
    )
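
_same_metadata gates the in-place copy_ fast path in Float8AQTTensorImpl.__torch_dispatch__ below: the copy proceeds tensor-by-tensor only when the shapes, the transposed flag, and the layout type all line up. A minimal sketch of what it accepts, with hypothetical shapes and using the classes defined later in this module:

# Illustrative sketch (hypothetical shapes; not part of the upstream source).
data = torch.randn(4, 8).to(torch.float8_e4m3fn)
scale = torch.ones(4)
a = Float8AQTTensorImpl(data, scale, False, Float8Layout())
b = Float8AQTTensorImpl(data.clone(), scale.clone(), False, Float8Layout())
assert _same_metadata(a, b)  # same shapes, same transposed flag, same layout type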


@dataclass(frozen=True)
class Float8Layout(Layout):
    """Represents the layout configuration for Float8 affine quantized tensors.

    Attributes:
        mm_config (Optional[Float8MMConfig]): Configuration for matrix multiplication
            operations involving Float8 tensors. If None, default settings are used.
    """

    mm_config: Optional[Float8MMConfig] = None
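
Float8Layout is the tag that routes an AffineQuantizedTensor into the float8 tensor impl and dispatch kernels below; its mm_config is read by _linear_fp8_act_fp8_weight_impl. A minimal construction sketch; the kernel below only reads use_fast_accum, and the exact Float8MMConfig constructor signature is an assumption here:

# Illustrative sketch (not part of the upstream source).
layout = Float8Layout(mm_config=Float8MMConfig(use_fast_accum=True))
assert layout.mm_config.use_fast_accum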
_fallback_warning_shown = False


@register_layout(Float8Layout)
class Float8AQTTensorImpl(AQTTensorImpl):
    """
    TensorImpl for float8 layout affine quantized tensor

    Note: technically we should not create a new layout for float8;
    we should merge this into plain layout
    """

    float8_data: torch.Tensor
    scale: torch.Tensor
    transposed: bool

    def __new__(
        cls,
        float8_data: torch.Tensor,
        scale: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        kwargs = {}
        kwargs["device"] = float8_data.device
        kwargs["layout"] = (
            kwargs.get("layout") if kwargs.get("layout", False) else float8_data.layout
        )
        kwargs["dtype"] = float8_data.dtype
        kwargs["requires_grad"] = False
        shape = float8_data.shape
        return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]

    def __init__(
        self,
        float8_data: torch.Tensor,
        scale: torch.Tensor,
        transposed: bool,
        _layout: Layout,
    ):
        self.float8_data = float8_data
        self.scale = scale
        self.transposed = transposed
        self._layout = _layout

    def _apply_fn_to_data(self, fn):
        """Applies a fn to all tensor components stored on this class"""
        global _fallback_warning_shown

        try:
            return self.__class__(
                fn(self.float8_data),
                fn(self.scale),
                self.transposed,
                self._layout,
            )
        except RuntimeError as e:
            if '"index_cuda" not implemented for ' in str(e):
                if not _fallback_warning_shown:
                    import warnings

                    warnings.warn(
                        f"When trying to index Float8AQTTensorImpl, got known error {e}, will use slower fallback but "
                        + "note: You can torch.compile the model to avoid this problem.",
                        UserWarning,
                    )
                    _fallback_warning_shown = True

                return self.__class__(
                    # do indexing in bfloat16 then convert back
                    fn(self.float8_data.to(torch.bfloat16)).to(self.float8_data.dtype),
                    fn(self.scale),
                    self.transposed,
                    self._layout,
                )
            else:
                raise e

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
        return self.__class__(
            self.float8_data.to(kwargs["device"]),
            self.scale.to(kwargs["device"]),
            self.transposed,
            self._layout,
        )

    def __tensor_flatten__(self):
        return ["float8_data", "scale"], [self.transposed, self._layout]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"]
        (
            transposed,
            _layout,
        ) = tensor_attributes
        return cls(float8_data, scale, transposed, _layout)

    @classmethod
    def __torch_dispatch__(cls, func, types, args, kwargs):
        kwargs = {} if kwargs is None else kwargs

        if func is aten.detach.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )
        elif func is aten.clone.default:
            return return_and_correct_aliasing(
                func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
            )
        elif func is aten.t.default:
            # We don't need to repack the weight; we just rely on the external
            # shape change and record the transpose status.
            args[0].transposed = not args[0].transposed
            return return_and_correct_aliasing(func, args, kwargs, args[0])
        elif func is aten.copy_.default:
            self = args[0]
            src = args[1]
            if _same_metadata(self, src):
                self_tensors = self.__tensor_flatten__()[0]
                for tensor_name in self_tensors:
                    getattr(self, tensor_name).copy_(getattr(src, tensor_name))
                return
            raise ValueError(
                f"Not supported args for copy_ due to metadata mismatch: {args[0], args[1]}"
            )
        elif func in [aten.select.int, aten.index.Tensor]:
            return return_and_correct_aliasing(
                func,
                args,
                kwargs,
                args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
            )
        elif func is aten.slice.Tensor:
            self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
            if dim == 0:
                # TODO: scale replication should be dependent on block size
                if self.scale.ndim == 1:
                    return return_and_correct_aliasing(
                        func,
                        args,
                        kwargs,
                        args[0]._apply_fn_to_data(
                            lambda x: aten.slice.Tensor(x, dim, start, end, step)
                        ),
                    )
                elif self.scale.ndim == 0:
                    return return_and_correct_aliasing(
                        func,
                        args,
                        kwargs,
                        Float8AQTTensorImpl(
                            aten.slice.Tensor(self.float8_data, dim, start, end, step),
                            self.scale,
                            None,
                            self._layout,
                        ),
                    )
                else:
                    raise NotImplementedError(
                        f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={self.scale.ndim}, that is not supported"
                    )
            elif dim == 1:
                return return_and_correct_aliasing(
                    func,
                    args,
                    kwargs,
                    Float8AQTTensorImpl(
                        aten.slice.Tensor(
                            self.float8_data, dim, start, end, step
                        ).contiguous(),
                        self.scale,
                        None,
                        self._layout,
                    ),
                )
            else:
                raise NotImplementedError(
                    f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported"
                )
        else:
            raise NotImplementedError(
                f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported"
            )

    __torch_function__ = torch._C._disabled_torch_function_impl

    def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        return self.float8_data, self.scale, None

    def get_layout(self) -> Layout:
        return self._layout

    @classmethod
    def from_plain(
        cls,
        data: torch.Tensor,
        scale: torch.Tensor,
        zero_point: Optional[torch.Tensor],
        _layout: Layout,
    ):
        """Main entrypoint for constructing Float8TensorImpl"""
        assert _is_float8_type(data.dtype), (
            f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}"
        )
        assert isinstance(_layout, Float8Layout), (
            f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}"
        )
        return cls(data, scale, False, _layout)

    def __repr__(self):
        float8_data, scale, _ = self.get_plain()
        _layout = self.get_layout()
        return (
            f"{self.__class__.__name__}(\n"
            f"float8_data={float8_data},\n"
            f"scale={scale},\n"
            f"transposed={self.transposed}, "
            f"_layout={_layout})"
        )
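
from_plain is the supported constructor for the impl: it asserts a float8 dtype and a Float8Layout, and get_plain hands the pieces back, with None standing in for the unused zero point. A quick round-trip sketch with hypothetical shapes:

# Illustrative sketch (hypothetical shapes; not part of the upstream source).
w = torch.randn(16, 32).to(torch.float8_e4m3fn)
w_scale = torch.ones(16)  # one scale per output row (rowwise)
impl = Float8AQTTensorImpl.from_plain(w, w_scale, None, Float8Layout())
data, scale, zero_point = impl.get_plain()
assert data.dtype == torch.float8_e4m3fn and zero_point is None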
##########################
# Float8 Dispatch Kernels
##########################


def _linear_fp8_act_fp8_weight_check(
    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
    bias: Optional[torch.Tensor],
) -> bool:
    def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool:
        return (
            isinstance(aqt, AffineQuantizedTensor)
            and isinstance(aqt._layout, Float8Layout)
            and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
            and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt))
        )

    return check_aqt(input_tensor) and check_aqt(weight_tensor)


def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]):
    """Ensures the input scale is correctly formatted for _scaled_mm"""
    input_scale = input_scale.unsqueeze(-1)

    if input_scale.dim() > 2:
        input_scale = input_scale.reshape(-1, input_scale.shape[-1])

    return input_scale
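
preprocess_scale appends a trailing unit dimension and, when the scale carries batch dimensions, flattens them so the result is 2-D; this mirrors the (-1, last_dim) reshape applied to the fp8 activation data in the impl below. A worked example with hypothetical sizes:

# Illustrative sketch (hypothetical sizes; not part of the upstream source).
s = torch.rand(2, 3)  # rowwise scales for a (2, 3, 128) activation
s2 = preprocess_scale(s, (2, 3, 128))
assert s2.shape == (6, 1)  # (2, 3) -> (2, 3, 1) -> (6, 1)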
def _linear_fp8_act_fp8_weight_impl(
    input_tensor: "AffineQuantizedTensor",
    weight_tensor: "AffineQuantizedTensor",
    bias: Optional[torch.Tensor],
):
    """Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
    scaled_mm_config = weight_tensor._layout.mm_config
    assert scaled_mm_config is not None
    out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

    # Weight tensor preprocessing
    w_tensor_impl = weight_tensor.tensor_impl
    assert not w_tensor_impl.transposed, "Weight tensor must be contiguous"
    w_data = w_tensor_impl.float8_data
    w_scale = w_tensor_impl.scale

    # Input tensor preprocessing
    inpt_data = input_tensor.tensor_impl.float8_data
    input_scale = input_tensor.tensor_impl.scale
    # Handle case where input tensor is more than 2D
    inpt_data = inpt_data.reshape(-1, inpt_data.shape[-1])

    # Handle rowwise case
    if _is_rowwise_scaled(weight_tensor):
        assert _is_rowwise_scaled(input_tensor), (
            "Input tensor must be rowwise scaled when the weight is rowwise scaled"
        )
        w_scale = w_scale.unsqueeze(-1).T
        input_scale = preprocess_scale(input_scale, input_tensor.shape)

    # Preprocess data
    inpt_data, w_data = preprocess_data(inpt_data, w_data.T, scaled_mm_config)

    # Perform the computation
    return addmm_float8_unwrapped_inference(
        inpt_data,
        input_scale,
        w_data,
        w_scale,
        output_dtype=input_tensor.dtype,
        bias=bias,
        use_fast_accum=scaled_mm_config.use_fast_accum,
    ).reshape(out_shape)


def _linear_fp_act_fp8_weight_check(
    input_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
    weight_tensor: Union[torch.Tensor, "AffineQuantizedTensor"],
    bias: Optional[torch.Tensor],
) -> bool:
    return (
        # input is native float tensor
        not is_traceable_wrapper_subclass(input_tensor)
        and input_tensor.is_floating_point()
        and
        # weight is float8 quantized affine quantized tensor
        isinstance(weight_tensor, AffineQuantizedTensor)
        and isinstance(weight_tensor._layout, Float8Layout)
        and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2]
        and (
            weight_tensor.shape == weight_tensor.block_size
            or _is_rowwise_scaled(weight_tensor)
        )
    )


def _linear_fp_act_fp8_weight_impl(
    input_tensor: torch.Tensor,
    weight_tensor: "AffineQuantizedTensor",
    bias: Optional[torch.Tensor],
):
    return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias)
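
In the rowwise-scaled case, _linear_fp8_act_fp8_weight_impl above arranges the scales so that _scaled_mm receives a column vector for the activation and a row vector for the weight. A worked shape sketch with hypothetical sizes:

# Illustrative sketch (hypothetical sizes; not part of the upstream source).
M, K, N = 6, 128, 64
input_scale = torch.rand(M)  # one scale per flattened activation row
w_scale = torch.rand(N)  # one scale per weight row (output feature)
input_scale = preprocess_scale(input_scale, (M, K))  # -> (6, 1)
w_scale = w_scale.unsqueeze(-1).T  # -> (1, 64)
assert input_scale.shape == (M, 1) and w_scale.shape == (1, N)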
