MXDynamicActivationMXWeightConfig

class torchao.prototype.mx_formats.MXDynamicActivationMXWeightConfig(block_size: int = 32, activation_dtype: dtype = torch.float8_e4m3fn, weight_dtype: dtype = torch.float8_e4m3fn, kernel_preference: KernelPreference = KernelPreference.AUTO, scaling_mode: ScaleCalculationMode = ScaleCalculationMode.RCEIL)

MX Format Inference Quantization

This configuration provides support for running inference with MX-format quantization of activations and weights, in either float8 (mxfp8) or float4 (mxfp4).
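To give a rough intuition for what `block_size` controls, the sketch below computes one shared power-of-two scale per block of 32 values, so that the largest magnitude in each block fits within the float8 e4m3 range (maximum finite value 448). This is an illustration only, not torchao's implementation: the helper name `block_scales` is hypothetical, and rounding the scale exponent up is an assumption loosely modeled on the name of the default `ScaleCalculationMode.RCEIL`. The limited range of a real scale format is also ignored here.

```python
import math

# Largest finite value representable in torch.float8_e4m3fn.
E4M3_MAX = 448.0


def block_scales(values, block_size=32, elem_max=E4M3_MAX):
    """Return one power-of-two scale per block of `block_size` values.

    Hypothetical helper for illustration; not part of torchao.
    """
    if len(values) % block_size != 0:
        raise ValueError("length must be a multiple of block_size")
    scales = []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        amax = max(abs(v) for v in block)
        if amax == 0.0:
            # An all-zero block needs no scaling.
            scales.append(1.0)
            continue
        # Assumption: round the exponent up so that amax / scale <= elem_max.
        exponent = math.ceil(math.log2(amax / elem_max))
        scales.append(2.0 ** exponent)
    return scales


# Two blocks of 32 values each, with very different magnitudes.
example = [448.0] * 32 + [1.0] * 32
print(block_scales(example))
```

Because each block gets its own scale, a block of small values is not forced to share the dynamic range of a block containing large values.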

Requirements:

- NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
- PyTorch 2.5+ for proper serialization support
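The requirements above can be checked before calling `quantize_`. The following is a minimal sketch, assuming that SM100 corresponds to CUDA compute capability (10, 0); the helper names `parse_major_minor` and `meets_mx_requirements` are hypothetical and not part of torchao.

```python
def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '2.5.1+cu124'."""
    parts = version.split("+")[0].split(".")
    return int(parts[0]), int(parts[1])


def meets_mx_requirements(capability, torch_version):
    """Check the documented requirements (hypothetical helper).

    capability: (major, minor) CUDA compute capability of the device.
    torch_version: PyTorch version string.
    """
    # Assumption: SM100 means compute capability 10.0 or higher.
    sm_ok = tuple(capability) >= (10, 0)
    torch_ok = parse_major_minor(torch_version) >= (2, 5)
    return sm_ok and torch_ok


print(meets_mx_requirements((10, 0), "2.5.1+cu124"))
print(meets_mx_requirements((9, 0), "2.6.0"))
```

On a machine with a CUDA device, the two arguments can be obtained from `torch.cuda.get_device_capability()` and `torch.__version__`.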

Example (mxfp8):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

model = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
    kernel_preference=KernelPreference.AUTO,
)
quantize_(model, config=config)
model = torch.compile(model, fullgraph=True)

Example (mxfp4):

import torch
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)
from torchao.quantization import quantize_
from torchao.quantization.quantize_.common import KernelPreference

model = nn.Linear(32, 128, bias=False, dtype=torch.bfloat16, device="cuda")
config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float4_e2m1fn_x2,
    weight_dtype=torch.float4_e2m1fn_x2,
    kernel_preference=KernelPreference.AUTO,
)
quantize_(model, config=config)
model = torch.compile(model, fullgraph=True)