
convert_to_float8_training

torchao.float8.convert_to_float8_training(module: Module, *, module_filter_fn: Optional[Callable[[Module, str], bool]] = None, config: Optional[Float8LinearConfig] = None) → Module

Swaps the torch.nn.Linear layers in module with Float8Linear layers.

Parameters:
  • module – Module to modify.

  • module_filter_fn – If specified, only the torch.nn.Linear subclasses that pass the filter function will be swapped. The inputs to the filter function are the module instance and its fully qualified name (FQN).

  • config (Float8LinearConfig) – Configuration for the conversion to float8.

Returns:

The modified module with swapped linear layers.

Return type:

nn.Module
