
convert_to_float8_training

torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) -> Module

Swaps the torch.nn.Linear layers in module with Float8Linear layers.

Parameters
  • module – Module to modify.

  • module_filter_fn – If specified, only the torch.nn.Linear subclasses that pass the filter function will be swapped. The filter function receives the module instance and its fully qualified name (FQN) and should return True for modules that should be swapped.

  • config (Float8LinearConfig) – Configuration for the conversion to float8. If not specified, the default Float8LinearConfig is used.

Returns

The modified module with swapped linear layers.

Return type

nn.Module