MXDynamicActivationMXWeightConfig
class torchao.prototype.mx_formats.MXDynamicActivationMXWeightConfig(block_size: int = 32, activation_dtype: dtype = torch.float8_e4m3fn, weight_dtype: dtype = torch.float8_e4m3fn, kernel_preference: KernelPreference = KernelPreference.AUTO, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL)
MX Format Inference Quantization
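This module provides support for running inference with MX-format quantization. Weights are quantized ahead of time and activations are quantized dynamically, using float8 (mxfp8) or float4 (mxfp4) elements that share one power-of-two scale per block of block_size values.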
Requirements:
- NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
- PyTorch 2.5+ for proper serialization support
Example (mxfp8):
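```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

# bfloat16 linear layer on CUDA (SM100+ is required to run the MX kernels)
model = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")

# mxfp8: activations and weights both use float8_e4m3fn elements
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    kernel_preference=KernelPreference.AUTO,
)
# quantize_ modifies the model in place
quantize_(model, config=config)
# fullgraph=True makes compilation fail instead of silently introducing graph breaks
model = torch.compile(model, fullgraph=True)
```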
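To make block_size and scaling_mode more concrete, here is a pure-Python sketch of MX-style quantization with float8_e4m3fn elements. Each block of 32 consecutive values shares one power-of-two scale (stored as an E8M0 exponent), and each value is divided by that scale and rounded to the nearest e4m3 value. The scale exponent is rounded up, ceil(log2(amax / 448)), which is an assumption about how RCEIL roughly behaves; the helpers (ceil_log2, round_to_e4m3, mx_quantize, mx_dequantize) are hypothetical and are not torchao's kernel code.

```python
import math

# float8_e4m3fn parameters (no infinities; largest finite value is 448)
E4M3_MAX = 448.0
E4M3_MANTISSA_BITS = 3
E4M3_MIN_NORMAL = 2.0 ** -6
E4M3_SUBNORMAL_STEP = 2.0 ** -9
E8M0_MIN_EXP, E8M0_MAX_EXP = -127, 127  # range of the shared scale exponent


def ceil_log2(y: float) -> int:
    """Exact ceil(log2(y)) for y > 0, using frexp to avoid log rounding issues."""
    m, e = math.frexp(y)  # y = m * 2**e with 0.5 <= m < 1
    return e - 1 if m == 0.5 else e


def round_to_e4m3(v: float) -> float:
    """Round v to the nearest float8_e4m3fn value (ties to even), saturating at +/-448."""
    a = abs(v)
    if a == 0.0:
        return 0.0
    if a >= E4M3_MIN_NORMAL:
        exp = math.frexp(a)[1] - 1  # floor(log2(a))
        step = 2.0 ** (exp - E4M3_MANTISSA_BITS)
    else:
        step = E4M3_SUBNORMAL_STEP  # subnormals are evenly spaced
    q = round(a / step) * step  # Python's round() is round-half-to-even
    return math.copysign(min(q, E4M3_MAX), v)


def mx_quantize(x, block_size=32):
    """Return (one scale exponent per block, scaled e4m3 element values)."""
    if len(x) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    exps, elems = [], []
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            e = E8M0_MIN_EXP
        else:
            # Round the scale up so amax / 2**e <= 448 (assumed RCEIL-like behavior)
            e = min(max(ceil_log2(amax / E4M3_MAX), E8M0_MIN_EXP), E8M0_MAX_EXP)
        exps.append(e)
        elems.extend(round_to_e4m3(v / 2.0 ** e) for v in block)
    return exps, elems


def mx_dequantize(exps, elems, block_size=32):
    """Multiply each element by its block's power-of-two scale."""
    return [v * 2.0 ** exps[i // block_size] for i, v in enumerate(elems)]
```

Rounding the exponent up guarantees that no value in a block is clipped, at the cost of leaving up to a factor of two of the e4m3 range unused. The reference scheme in the OCP MX specification instead uses floor(log2(amax)) minus the element format's largest exponent (8 for E4M3), which can saturate values near the block maximum.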
Example (mxfp4):
```python
import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

# bfloat16 linear layer on CUDA (SM100+ is required to run the MX kernels)
model = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")

# mxfp4: float4_e2m1fn_x2 packs two FP4 (E2M1) values into each byte
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    kernel_preference=KernelPreference.AUTO,
)
# quantize_ modifies the model in place
quantize_(model, config=config)
model = torch.compile(model, fullgraph=True)
```
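The mxfp4 example differs from the mxfp8 one only in the element dtype. torch.float4_e2m1fn_x2 stores two 4-bit E2M1 values per byte, and E2M1 can represent only ±{0, 0.5, 1, 1.5, 2, 3, 4, 6}, so the shared per-block scale carries most of the dynamic range. The pure-Python sketch below illustrates this under stated assumptions: the scale exponent is rounded up against the FP4 maximum of 6.0 (assumed to approximate RCEIL), and the first value of each pair goes in the low nibble (an assumed order; torchao's packing layout may differ). All helper names are hypothetical.

```python
import math

# Magnitudes representable by FP4 E2M1, indexed by the 3 low bits of the 4-bit code
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
E2M1_MAX = 6.0


def ceil_log2(y: float) -> int:
    """Exact ceil(log2(y)) for y > 0."""
    m, e = math.frexp(y)  # y = m * 2**e with 0.5 <= m < 1
    return e - 1 if m == 0.5 else e


def encode_e2m1(v: float) -> int:
    """4-bit code (sign bit 0x8 | magnitude index) of the E2M1 value nearest v; saturates at +/-6."""
    sign = 0x8 if math.copysign(1.0, v) < 0 else 0x0
    a = min(abs(v), E2M1_MAX)
    # Nearest magnitude; on a tie prefer the even index (mantissa bit 0), i.e. ties to even
    idx = min(range(8), key=lambda i: (abs(E2M1_VALUES[i] - a), i % 2))
    return sign | idx


def decode_e2m1(code: int) -> float:
    mag = E2M1_VALUES[code & 0x7]
    return -mag if code & 0x8 else mag


def pack_x2(codes):
    """Pack two 4-bit codes per byte; low-nibble-first is an assumed order."""
    if len(codes) % 2:
        raise ValueError("need an even number of FP4 codes")
    return bytes(lo | (hi << 4) for lo, hi in zip(codes[0::2], codes[1::2]))


def unpack_x2(data):
    codes = []
    for byte in data:
        codes.extend((byte & 0xF, byte >> 4))
    return codes


def mxfp4_quantize(x, block_size=32):
    """Return (one scale exponent per block, packed FP4 bytes); block_size // 2 bytes per block."""
    if len(x) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    exps, packed = [], bytearray()
    for start in range(0, len(x), block_size):
        block = x[start:start + block_size]
        amax = max(abs(v) for v in block)
        # Round the scale up so amax / 2**e <= 6 (assumed RCEIL-like behavior)
        e = -127 if amax == 0.0 else min(max(ceil_log2(amax / E2M1_MAX), -127), 127)
        exps.append(e)
        packed += pack_x2([encode_e2m1(v / 2.0 ** e) for v in block])
    return exps, bytes(packed)
```

With a block size of 32, each block costs 16 bytes of packed elements plus one scale exponent, which is where mxfp4's memory savings over mxfp8 come from.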