convert_to_float8_training
- torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) → Module
Swaps each torch.nn.Linear in module, including those in nested submodules, with Float8Linear.
- Parameters:
module – Module to modify.
module_filter_fn – If specified, only the torch.nn.Linear modules (including subclasses) for which the filter function returns True are swapped. The filter function receives the module instance and its fully qualified name (FQN), e.g. "encoder.0".
config (Float8LinearConfig) – Configuration for the conversion to float8. If None, a default Float8LinearConfig() is used.
- Returns:
The modified module with swapped linear layers.
- Return type:
nn.Module