# Source code for torchao.prototype.mx_formats.inference_workflow

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import types
from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.prototype.mx_formats.config import (
    _validate_elem_dtype,
    _validate_kernel_preference,
)
from torchao.prototype.mx_formats.mx_tensor import (
    MXTensor,
    QuantizeTensorToMXKwargs,
    ScaleCalculationMode,
)
from torchao.prototype.mx_formats.nvfp4_tensor import (
    NVFP4Tensor,
    QuantizeTensorToNVFP4Kwargs,
    per_tensor_amax_to_scale,
)
from torchao.quantization.quant_api import _quantization_type
from torchao.quantization.quantize_.common.kernel_preference import KernelPreference
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)
from torchao.utils import (
    is_sm_at_least_100,
    torch_version_at_least,
)


@dataclass
class MXDynamicActivationMXWeightConfig(AOBaseConfig):
    """
    MX Format Inference Quantization

    This module provides support for running inference with float8 quantization using MX formats.

    Requirements:
    - NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
    - PyTorch 2.5+ for proper serialization support
    """

    block_size: int = 32

    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
    activation_dtype: torch.dtype = torch.float8_e4m3fn
    weight_dtype: torch.dtype = torch.float8_e4m3fn

    # Which kernel to run for mm
    kernel_preference: KernelPreference = KernelPreference.AUTO

    # How to calculate the block scales
    scaling_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL

    def __post_init__(self):
        # Mixed input/weight dtypes are not supported yet.
        assert self.activation_dtype == self.weight_dtype, (
            "For now - we only support matching input/weight dtypes."
        )
        # Both element dtypes go through the same validation.
        for elem_dtype in (self.activation_dtype, self.weight_dtype):
            _validate_elem_dtype(elem_dtype)
        _validate_kernel_preference(
            self.kernel_preference, self.block_size, self.weight_dtype
        )
def _linear_extra_repr(self):
    """Replacement ``extra_repr`` for quantized linear modules: shows
    feature sizes plus the quantization type of the weight tensor."""
    return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


@register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
def _mx_inference_linear_transform(
    module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
):
    """Quantization handler for MXDynamicActivationMXWeightConfig.

    Swaps ``module.weight`` for an ``MXTensor`` (which also carries the
    activation-quantization kwargs) and patches ``extra_repr``.
    """
    weight = module.weight
    assert weight.dtype == torch.bfloat16, (
        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
    )

    # The activation kwargs and the weight conversion share these settings.
    common_kwargs = dict(
        block_size=config.block_size,
        kernel_preference=config.kernel_preference,
        is_swizzled_scales=True,
        scaling_mode=config.scaling_mode,
    )
    act_quant_kwargs = QuantizeTensorToMXKwargs(
        elem_dtype=config.activation_dtype,
        **common_kwargs,
    )

    # Convert weight to MX Tensor
    quantized_weight = MXTensor.to_mx(
        weight,
        config.weight_dtype,
        act_quant_kwargs=act_quant_kwargs,
        **common_kwargs,
    )

    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module
@dataclass
class NVFP4DynamicActivationNVFP4WeightConfig(AOBaseConfig):
    """
    NVIDIA FP4 (NVFP4) Inference Quantization Configuration

    This is a specialized configuration for NVIDIA's FP4 format.
    Configuration parameters:
    - use_triton_kernel: bool, whether to use fused triton kernel for activation scaling (default: True)
    - use_dynamic_per_tensor_scale: bool, whether to dynamically compute per tensor scale (default: True)
    - Data: float4_e2m1fn_x2
    - Scales: float8_e4m3fn
    - Block size: 16 along the reduction dim

    Note: Triton kernel only works with DYNAMIC mode and has constraints that input
    dimensions must satisfy M % 128 == 0 and K % 64 == 0. Will automatically
    fallback when constraints aren't met.
    """

    use_triton_kernel: bool = True
    use_dynamic_per_tensor_scale: bool = True

    def __post_init__(self):
        # Validate PyTorch version: NVFP4 support needs 2.8+.
        if torch_version_at_least("2.8.0"):
            return
        raise RuntimeError(
            "NVFP4DynamicActivationNVFP4WeightConfig requires PyTorch 2.8 or later"
        )
@register_quantize_module_handler(NVFP4DynamicActivationNVFP4WeightConfig)
def _nvfp4_inference_linear_transform(
    module: torch.nn.Linear, config: NVFP4DynamicActivationNVFP4WeightConfig
):
    """Quantization handler for NVFP4DynamicActivationNVFP4WeightConfig"""
    assert is_sm_at_least_100(), (
        "NVFP4 DYNAMIC mode is only supported on sm100+ machines"
    )
    weight = module.weight
    # Both of the weight's trailing dims must be multiples of 16.
    if any(dim % 16 != 0 for dim in weight.shape[-2:]):
        raise RuntimeError(
            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
        )

    # Optionally derive a per-tensor scale from the weight's absolute max.
    per_tensor_scale = (
        per_tensor_amax_to_scale(weight.abs().max())
        if config.use_dynamic_per_tensor_scale
        else None
    )

    act_quant_kwargs = QuantizeTensorToNVFP4Kwargs(
        use_dynamic_per_tensor_scale=config.use_dynamic_per_tensor_scale,
        use_triton_kernel=config.use_triton_kernel,
        is_swizzled_scales=True,
    )
    quantized_weight = NVFP4Tensor.to_nvfp4(
        weight,
        per_tensor_scale=per_tensor_scale,
        is_swizzled_scales=True,
        use_triton_kernel=False,  # Always use traditional construction for weights
        act_quant_kwargs=act_quant_kwargs,
    )
    # Set triton preference after construction
    quantized_weight.use_triton_kernel = config.use_triton_kernel

    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module
@dataclass
class NVFP4WeightOnlyConfig(AOBaseConfig):
    """
    NVIDIA FP4 (NVFP4) weight-only quantization configuration.

    Only the weight is converted to an ``NVFP4Tensor``; activations are left
    in their original dtype (the handler passes ``act_quant_kwargs=None``).

    Configuration parameters:
    - use_dynamic_per_tensor_scale: bool, whether to dynamically compute a
      per-tensor scale from the weight's amax (default: True)
    """

    use_dynamic_per_tensor_scale: bool = True

    def __post_init__(self):
        # Validate PyTorch version
        if not torch_version_at_least("2.8.0"):
            # Bug fix: the message previously named the wrong config class
            # (NVFP4DynamicActivationNVFP4WeightConfig).
            raise RuntimeError(
                "NVFP4WeightOnlyConfig requires PyTorch 2.8 or later"
            )
@register_quantize_module_handler(NVFP4WeightOnlyConfig)
def _nvfp4_weight_only_linear_transform(
    module: torch.nn.Linear, config: NVFP4WeightOnlyConfig
):
    """Quantization handler for NVFP4WeightOnlyConfig"""
    weight = module.weight
    # Both trailing dims must be multiples of 16 for NVFP4 (16-wide scale blocks).
    if weight.shape[-2] % 16 != 0 or weight.shape[-1] % 16 != 0:
        raise RuntimeError(
            f"NVFP4 only supports weight shape with last 2 dims divisible by 16, got {weight.shape}"
        )

    # Optionally derive a per-tensor scale from the weight's absolute max.
    per_tensor_scale = None
    if config.use_dynamic_per_tensor_scale:
        tensor_amax = torch.max(torch.abs(weight))
        per_tensor_scale = per_tensor_amax_to_scale(tensor_amax)

    # act_quant_kwargs=None marks this tensor as weight-only: activations are
    # not quantized at runtime.
    quantized_weight = NVFP4Tensor.to_nvfp4(
        weight,
        per_tensor_scale=per_tensor_scale,
        is_swizzled_scales=True,
        act_quant_kwargs=None,
    )
    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


# Allow these tensor subclasses / kwargs dataclasses to be deserialized under
# torch.load(..., weights_only=True).
torch.serialization.add_safe_globals(
    [
        MXTensor,
        NVFP4Tensor,
        QuantizeTensorToMXKwargs,
        QuantizeTensorToNVFP4Kwargs,
        ScaleCalculationMode,
    ]
)

import torch.nn as nn


def _auto_filter_for_nfp4(mod: nn.Module, fqn: str) -> bool:
    """Generic Filter fn for NVFP4 that is best practice for most models.

    Returns True only for ``nn.Linear`` modules that are not embedding-like
    (by fqn substring), have both dims divisible by 16, and are large enough
    that NVFP4 quantization is expected to help.
    """
    # Define any FQNs you want to exclude directly in the function
    filter_fqns = ["embedder", "embed", "embedding", "time_text_embed"]

    # Only support Linear modules
    if not isinstance(mod, nn.Linear):
        return False

    # If the fqn matches any filtered fqn, then we should not convert this module
    is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
    if is_filtered_fqn:
        return False

    # All dims must be divisible by 16 due to NVFP4 hardware requirements
    # (the comment previously said "float8", which was a copy-paste leftover).
    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
    if not dims_multiples_of_16:
        return False

    if N <= 64:
        # Typo fix: was "skiping".  TODO cublas doesn't like this one
        print("skipping small linear layer")
        return False

    # Dims below these thresholds may result in worse performance
    if K <= 1024 and N <= 1024:
        print("skipping small linear layer")
        return False

    return True