convert_to_float8_training
- torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) → Module
Swaps torch.nn.Linear in module with Float8Linear.
- Parameters
module – Module to modify.
module_filter_fn – If specified, only the torch.nn.Linear instances for which the filter function returns True will be swapped. The filter function receives the module instance and its FQN (fully-qualified name).
config (Float8LinearConfig) – Configuration for the float8 conversion.
- Returns
The modified module with swapped linear layers.
- Return type
nn.Module
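In the simplest case the function can be called with no filter and no config, converting every nn.Linear in place. A minimal sketch (this assumes the default Float8LinearConfig, which uses tensorwise scaling, and a float8-capable GPU such as an H100); a fuller example with a filter function and an explicit recipe config follows below:

    import torch.nn as nn

    from torchao.float8 import convert_to_float8_training

    m = nn.Sequential(nn.Linear(64, 64, bias=False)).bfloat16().cuda()
    # with no filter and no config, all nn.Linear modules are swapped using
    # the default Float8LinearConfig; the swap happens in place and the
    # module is also returned for convenience
    m = convert_to_float8_training(m)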
Example:
    import torch
    import torch.nn as nn

    from torchao.float8 import convert_to_float8_training, Float8LinearConfig

    # create model and sample input
    m = nn.Sequential(
        nn.Linear(8192, 4096, bias=False),
        nn.Linear(4096, 128, bias=False),
    ).bfloat16().cuda()
    optimizer = torch.optim.SGD(m.parameters(), lr=0.1)

    # optional: filter modules from being eligible for float8 conversion
    def module_filter_fn(mod: torch.nn.Module, fqn: str):
        # don't convert the last module
        if fqn == "1":
            return False
        # don't convert linear modules with weight dimensions not divisible by 16
        if isinstance(mod, torch.nn.Linear):
            if mod.in_features % 16 != 0 or mod.out_features % 16 != 0:
                return False
        return True

    # configure float8 recipe
    # valid recipe names: "tensorwise", "rowwise", "rowwise_with_gw_hp"
    config = Float8LinearConfig.from_recipe_name("tensorwise")

    # convert specified `torch.nn.Linear` modules to `Float8Linear`
    convert_to_float8_training(m, config=config, module_filter_fn=module_filter_fn)

    # enable torch.compile for competitive performance
    m = torch.compile(m)

    # training loop
    x = torch.randn(4096, 8192, device="cuda", dtype=torch.bfloat16)
    for _ in range(10):
        optimizer.zero_grad()
        y = m(x)
        y.sum().backward()
        optimizer.step()
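To confirm which layers were actually swapped, the module tree can be inspected right after conversion (before torch.compile wraps the model). A minimal sketch, assuming Float8Linear is importable from torchao.float8 as in current releases:

    from torchao.float8 import Float8Linear

    # print the FQN of every converted layer; with the filter above,
    # only module "0" should appear
    for fqn, mod in m.named_modules():
        if isinstance(mod, Float8Linear):
            print(f"{fqn}: Float8Linear(in={mod.in_features}, out={mod.out_features})")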