
torchao.float8

Created On: Dec 17, 2025 | Last Updated On: Dec 17, 2025

Main float8 training APIs

convert_to_float8_training

Swaps the torch.nn.Linear layers in module with Float8Linear.
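convert_to_float8_training can take a module_filter_fn keyword argument that decides which torch.nn.Linear layers are swapped. The filter receives each module and its fully qualified name (FQN). The sketch below is minimal and rests on a few assumptions. The policy (skip a layer named "output" and any Linear whose dimensions are not multiples of 16) and the ToyModel class are illustrative, not library defaults. The sketch needs only PyTorch, so it evaluates the filter directly rather than performing the swap.

```python
import torch.nn as nn


def module_filter_fn(mod: nn.Module, fqn: str) -> bool:
    # Illustrative policy: keep the final projection (named "output" here) in high precision
    if fqn == "output":
        return False
    # Skip Linear layers whose weight dimensions are not multiples of 16
    if isinstance(mod, nn.Linear):
        if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
            return False
    return True


class ToyModel(nn.Module):
    # Hypothetical model used only to exercise the filter
    def __init__(self) -> None:
        super().__init__()
        self.proj_in = nn.Linear(64, 128)
        self.proj_odd = nn.Linear(128, 100)  # 100 is not a multiple of 16
        self.output = nn.Linear(100, 10)

    def forward(self, x):
        return self.output(self.proj_odd(self.proj_in(x)))


model = ToyModel()
# Collect the FQNs of the Linear layers the filter would allow to be swapped
eligible = [
    fqn
    for fqn, mod in model.named_modules()
    if isinstance(mod, nn.Linear) and module_filter_fn(mod, fqn)
]
print(eligible)
```

With torchao installed, the same function would be passed as convert_to_float8_training(model, module_filter_fn=module_filter_fn). Running the converted model with hardware float8 matmuls also needs a GPU that supports float8.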

Other float8 training types

Float8LinearConfig

Configuration for converting a torch.nn.Linear module to float8 for training.

CastConfig

Configuration for optionally casting a single tensor to float8.

ScalingType

ScalingGranularity

Defines the granularity of scaling strategies for casting to float8.

precompute_float8_dynamic_scale_for_fsdp

Calculate the scales dynamically for all float8 parameters. Run this after the optimizer step. It performs a single all-reduce to compute the scales for all float8 weights at once. Example usage:

    model(input).sum().backward()
    optim.step()
    precompute_float8_dynamic_scale_for_fsdp(model)  # one all-reduce for all float8 weight scales