
Source code for torchao.float8.config

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import enum
import logging
from dataclasses import dataclass
from typing import Optional, Union

import torch

from torchao.utils import is_MI300

logger: logging.Logger = logging.getLogger()


class ScalingType(enum.Enum):
    DYNAMIC = "dynamic"
    # ScalingType.DISABLED means "skip scaling for this tensor, leave it in
    # its original precision".
    DISABLED = "disabled"

    def short_str(self):
        if self is ScalingType.DYNAMIC:
            return "dyn"
        else:
            assert self is ScalingType.DISABLED
            return "dis"
class ScalingGranularity(enum.Enum):
    """
    Defines the granularity of scaling strategies for casting to float8
    """

    # A single scaling factor for the entire tensor
    TENSORWISE = "tensorwise"
    # Scaling factors computed along one axis of the tensor, reducing it to
    # size 1.
    AXISWISE = "axiswise"

    def short_str(self):
        if self is ScalingGranularity.TENSORWISE:
            return "ten"
        else:
            assert self is ScalingGranularity.AXISWISE
            return "axs"
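# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal sketch of how the two enums above behave; `short_str()` values
# are used later to build compact config descriptions.
#
#     ScalingType.DYNAMIC.short_str()           # "dyn"
#     ScalingType.DISABLED.short_str()          # "dis"
#     ScalingGranularity.AXISWISE.short_str()   # "axs"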
@dataclass
class Float8TypeConfig:
    """
    Configuration for selecting the preferred float8 type pair, either
    e4m3fn/e5m2 or e4m3fnuz/e5m2fnuz.

    Currently, ROCm supports
    1. fnuz variants in MI300.
    2. OCP F8 variants in MI350/Navi4.
    """

    # The preferred e4m3 type.
    e4m3_dtype = torch.float8_e4m3fn

    # The preferred e5m2 type.
    e5m2_dtype = torch.float8_e5m2

    def __post_init__(self):
        if torch.version.hip and torch.cuda.is_available() and is_MI300():
            self.e4m3_dtype = torch.float8_e4m3fnuz
            self.e5m2_dtype = torch.float8_e5m2fnuz


# User defined type for using the individual F8 type based on config
type_config = Float8TypeConfig()
e4m3_dtype = type_config.e4m3_dtype
e5m2_dtype = type_config.e5m2_dtype
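# --- Usage sketch (illustrative, not part of the original module) ---
# A hedged sketch of how the module-level dtype pair resolves at import time:
# on typical CUDA hardware (and MI350/Navi4 ROCm) the OCP types are kept,
# while on ROCm MI300 the fnuz variants are substituted by __post_init__.
#
#     cfg = Float8TypeConfig()
#     print(cfg.e4m3_dtype, cfg.e5m2_dtype)
#     # e.g. torch.float8_e4m3fn torch.float8_e5m2   (CUDA)
#     # or   torch.float8_e4m3fnuz torch.float8_e5m2fnuz   (MI300)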
@dataclass(frozen=True)
class CastConfig:
    """
    Configuration for maybe casting a single tensor to float8
    """

    scaling_type: ScalingType = ScalingType.DYNAMIC
    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
    target_dtype: Optional[torch.dtype] = None

    def short_str(self):
        dtype = {e4m3_dtype: "e4m3", e5m2_dtype: "e5m2"}[self.target_dtype]
        return f"{self.scaling_type.short_str()}_{self.scaling_granularity.short_str()}_{dtype}"

    def __post_init__(self):
        if self.scaling_granularity is ScalingGranularity.AXISWISE:
            assert self.scaling_type is ScalingType.DYNAMIC, (
                "only dynamic scaling type is supported for axiswise scaling granularity"
            )
        assert self.target_dtype is None or (
            self.target_dtype.is_floating_point and self.target_dtype.itemsize == 1
        ), "must specify an 8-bit floating-point dtype"
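# --- Usage sketch (illustrative, not part of the original module) ---
# Constructing per-tensor cast configs from the class above. Axiswise
# granularity requires dynamic scaling, as enforced in __post_init__.
#
#     cc_tensorwise = CastConfig()  # dynamic, tensorwise; dtype filled in later
#     cc_axiswise = CastConfig(
#         scaling_granularity=ScalingGranularity.AXISWISE,
#         target_dtype=e4m3_dtype,
#     )
#     cc_axiswise.short_str()  # "dyn_axs_e4m3"
#
#     # This would fail the __post_init__ assert, since axiswise scaling
#     # is only supported with ScalingType.DYNAMIC:
#     # CastConfig(scaling_type=ScalingType.DISABLED,
#     #            scaling_granularity=ScalingGranularity.AXISWISE)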
@dataclass(frozen=True)
class Float8GemmConfig:
    """
    Configuration for a float8 gemm.
    """

    # If True, fast accumulation in lower precision is used.
    # Note: this flag is currently a no-op if emulation is turned on.
    use_fast_accum: bool = False


# Pre-made recipes for common configurations
class Float8LinearRecipeName(enum.Enum):
    # Default, dynamic per-tensor scaling with the cuBLAS tensorwise kernel
    TENSORWISE = "tensorwise"

    # dynamic rowwise scaling with the CUTLASS rowwise kernel
    # * e4m3 for activations, weights, gradients
    # * scales rounded (floor) to the nearest power of two for increased accuracy
    ROWWISE = "rowwise"

    # lw's recipe for a modification on rowwise scaling:
    #
    #   output_hp = input_fp8_rowwise_dim0 @ weight_t_rowwise_dim1
    #   grad_input_hp = grad_output_fp8_rowwise_dim0 @ weight_fp8_tensorwise
    #   grad_weight_hp = input_t_hp @ grad_output_hp
    #
    # key characteristics:
    # * increased accuracy for grad_weight
    # * `input`, `weight` and `grad_output` now only need to be scaled
    #   rowwise across a single dim compared to vanilla rowwise,
    #   which is more amenable to fast kernels
    # * the e4m3 dtype is used across the board, including for gradients
    ROWWISE_WITH_GW_HP = "rowwise_with_gw_hp"
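# --- Usage sketch (illustrative, not part of the original module) ---
# Recipe names round-trip between their string values and enum members;
# `Float8LinearConfig.from_recipe_name` below accepts either form.
#
#     for recipe in Float8LinearRecipeName:
#         assert Float8LinearRecipeName(recipe.value) is recipe
#     [r.value for r in Float8LinearRecipeName]
#     # ['tensorwise', 'rowwise', 'rowwise_with_gw_hp']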
@dataclass(frozen=True)
class Float8LinearConfig:
    """
    Configuration for converting a `torch.nn.Linear` module to float8
    for training.
    """

    #
    # Per-tensor configuration for casting of `input`, `weight`, `grad_output`
    # for the operands of gemms calculating `output`, `grad_weight`, and `grad_input`.
    #
    # Note:
    # 1. if `cast_config_input_for_grad_weight` is None, then
    #    `cast_config_input` is used for scaling `input` for both gemms that
    #    use `input`.
    # 2. if `cast_config_input_for_grad_weight` is specified, then
    #    a. `cast_config_input` is used for scaling `input` for the gemm that calculates
    #       `output`
    #    b. `cast_config_input_for_grad_weight` is used for scaling `input` for
    #       the gemm that calculates `grad_weight`
    # 3. the same behavior holds for `cast_config_weight` and `cast_config_grad_output`.
    #
    # `input`
    cast_config_input: CastConfig = CastConfig()
    cast_config_input_for_grad_weight: Optional[CastConfig] = None
    # `weight`
    cast_config_weight: CastConfig = CastConfig()
    cast_config_weight_for_grad_input: Optional[CastConfig] = None
    # `grad_output`
    cast_config_grad_output: CastConfig = CastConfig()
    cast_config_grad_output_for_grad_weight: Optional[CastConfig] = None

    #
    # Per-gemm configuration for gemms calculating `output`, `grad_input` and
    # `grad_weight`
    #
    gemm_config_output: Float8GemmConfig = Float8GemmConfig(use_fast_accum=True)
    gemm_config_grad_input: Float8GemmConfig = Float8GemmConfig()
    gemm_config_grad_weight: Float8GemmConfig = Float8GemmConfig()

    #
    # Per-linear configuration
    #

    # If True, then uses a tensor subclass for the float8 linear module's weight that
    # implements pre/post-all-gather methods to do float8 all-gather with FSDP2.
    enable_fsdp_float8_all_gather: bool = False

    # If True, then prior to performing the fp8 scaled matmul we will pad the
    # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for
    # `_scaled_mm` since it has the strong constraint that, for M,N,K, both N
    # and K must be multiples of 16. This can cause a memory spike, however,
    # so we keep this off by default.
    pad_inner_dim: bool = False

    # If True, emulation is used instead of hardware accelerated gemm
    emulate: bool = False

    # This flag is deprecated and currently has no effect. It will be removed
    # in a future release. Please see https://github.com/pytorch/ao/issues/2251
    # for more context.
    force_recompute_fp8_weight_in_bwd: bool = False

    # If this option is enabled, the scaling factor used for float8 quantization
    # will be rounded down to the nearest power of 2. This has been shown to help
    # reduce quantization error by avoiding rounding errors when multiplying/dividing
    # by the scaling factor, as well as ensuring large values are quantized to the
    # same value in the forward pass as in the backward passes.
    round_scales_to_power_of_2: bool = False

    def __post_init__(self):
        # Populate the additional cast overrides, if the user did not specify them.
        # Note: this hacks around the frozen-ness of this dataclass
        # by using `object.__setattr__`. This is fine, as what we really need
        # is for this object to be frozen after `__post_init__` for torch.compile
        # to work.
        # Source of hack: https://stackoverflow.com/a/65959419/
        if self.cast_config_input_for_grad_weight is None:
            object.__setattr__(
                self, "cast_config_input_for_grad_weight", self.cast_config_input
            )
        if self.cast_config_weight_for_grad_input is None:
            object.__setattr__(
                self, "cast_config_weight_for_grad_input", self.cast_config_weight
            )
        if self.cast_config_grad_output_for_grad_weight is None:
            object.__setattr__(
                self,
                "cast_config_grad_output_for_grad_weight",
                self.cast_config_grad_output,
            )

        # float8 all-gather only supports tensorwise, in the future may support blockwise
        if self.cast_config_weight.scaling_granularity != ScalingGranularity.TENSORWISE:
            assert not self.enable_fsdp_float8_all_gather, (
                f"enable_fsdp_float8_all_gather only supports tensorwise scaling granularity, got {self.cast_config_weight.scaling_granularity}"
            )

        # save some characters in the compatibility checks below
        cc_i = self.cast_config_input
        cc_w = self.cast_config_weight
        cc_go = self.cast_config_grad_output
        cc_i_gw = self.cast_config_input_for_grad_weight
        cc_w_gi = self.cast_config_weight_for_grad_input
        cc_go_gw = self.cast_config_grad_output_for_grad_weight

        # for now, we only have gemm kernels where both operands are either both
        # in high precision, or both in float8. In the future, this may be relaxed.
        # TODO(future): make the float8 check more precise with the specific dtypes.
        for cc1, cc2, gemm_name in (
            (cc_i, cc_w, "output"),
            (cc_go, cc_w_gi, "grad_input"),
            (cc_i_gw, cc_go_gw, "grad_weight"),
        ):
            is_disabled_1 = cc1.scaling_type is ScalingType.DISABLED
            is_disabled_2 = cc2.scaling_type is ScalingType.DISABLED
            assert is_disabled_1 == is_disabled_2, (
                f"incompatible operand precision for {gemm_name}"
            )

        for cc1, cc2, operand_name, default_dtype in [
            (cc_i, cc_i_gw, "input", e4m3_dtype),
            (cc_w, cc_w_gi, "weight", e4m3_dtype),
            (cc_go, cc_go_gw, "grad_output", e5m2_dtype),
        ]:
            # Override the dataclass being frozen
            if cc1.target_dtype is None:
                object.__setattr__(cc1, "target_dtype", default_dtype)
            if cc2.target_dtype is None:
                object.__setattr__(cc2, "target_dtype", default_dtype)
            assert cc1.target_dtype == cc2.target_dtype, (
                f"{operand_name} must be cast to the same dtype in both matmuls it's used in"
            )

        if self.force_recompute_fp8_weight_in_bwd:
            logger.warning(
                "`config.force_recompute_fp8_weight_in_bwd` is deprecated and will "
                "be removed in a future release. Please see "
                "https://github.com/pytorch/ao/issues/2251 for more details."
            )
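    # --- Usage sketch (illustrative, not part of the original module) ---
    # How the defaults resolve after __post_init__: the per-grad-weight
    # overrides are populated from the base cast configs, and unset target
    # dtypes fall back to e4m3 for `input`/`weight` and e5m2 for `grad_output`.
    #
    #     config = Float8LinearConfig()
    #     assert config.cast_config_input_for_grad_weight is config.cast_config_input
    #     config.cast_config_input.short_str()        # "dyn_ten_e4m3"
    #     config.cast_config_grad_output.short_str()  # "dyn_ten_e5m2"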
    @staticmethod
    def from_recipe_name(
        recipe_name: Union[Float8LinearRecipeName, str],
    ) -> "Float8LinearConfig":
        """
        Input: `Float8LinearRecipeName` value, or a string representing a
            `Float8LinearRecipeName` value
        Output: a `Float8LinearConfig` configured to implement the specified recipe
        """
        if type(recipe_name) == str:
            valid_names = [n.value for n in Float8LinearRecipeName]
            assert recipe_name in valid_names, (
                f"recipe_name {recipe_name} not in valid names {valid_names}"
            )
            recipe_name = Float8LinearRecipeName(recipe_name)

        if recipe_name is Float8LinearRecipeName.TENSORWISE:
            return Float8LinearConfig()

        elif recipe_name is Float8LinearRecipeName.ROWWISE:
            cc_i = CastConfig(
                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
            )
            cc_w = CastConfig(
                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
            )
            cc_go = CastConfig(
                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
            )

            return Float8LinearConfig(
                cast_config_input=cc_i,
                cast_config_weight=cc_w,
                cast_config_grad_output=cc_go,
                # enable power of 2 scaling factors by default for row-wise scaling
                round_scales_to_power_of_2=True,
            )

        elif recipe_name is Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
            # output_hp = input_fp8_axiswise_dim0 @ weight_t_axiswise_dim1
            cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
            cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

            # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
            cc_go = CastConfig(
                scaling_granularity=ScalingGranularity.AXISWISE, target_dtype=e4m3_dtype
            )
            cc_w_gi = CastConfig(scaling_granularity=ScalingGranularity.TENSORWISE)

            # grad_weight_hp = input_t_hp @ grad_output_hp
            cc_i_gw = CastConfig(scaling_type=ScalingType.DISABLED)
            cc_go_gw = CastConfig(
                scaling_type=ScalingType.DISABLED, target_dtype=e4m3_dtype
            )

            return Float8LinearConfig(
                cast_config_input=cc_i,
                cast_config_weight=cc_w,
                cast_config_grad_output=cc_go,
                cast_config_input_for_grad_weight=cc_i_gw,
                cast_config_weight_for_grad_input=cc_w_gi,
                cast_config_grad_output_for_grad_weight=cc_go_gw,
            )

        else:
            raise AssertionError(f"unknown recipe_name {recipe_name}")
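
# --- Usage sketch (illustrative, not part of the original module) ---
# A hedged end-to-end sketch: build a config from a recipe name and apply it
# to a model. It assumes the `convert_to_float8_training` entry point exported
# by `torchao.float8` (defined elsewhere, not in this file) and hardware with
# float8 support; the recipe can be given as a string or an enum member.
#
#     import torch.nn as nn
#     from torchao.float8 import convert_to_float8_training
#
#     rowwise_config = Float8LinearConfig.from_recipe_name("rowwise")
#     assert rowwise_config.round_scales_to_power_of_2
#
#     model = nn.Sequential(
#         nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)
#     ).cuda()
#     convert_to_float8_training(model, config=rowwise_config)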
