torchao.float8
Main float8 training APIs
Swaps torch.nn.Linear layers in a module with Float8Linear.
Other float8 training types
Configuration for converting a torch.nn.Linear module to float8 for training.
Configuration for optionally casting a single tensor to float8.
Defines the granularity of the scaling strategy used when casting to float8.
Calculates scales dynamically for all float8 parameters. Run it after the optimizer step. It performs a single all-reduce to compute the scales for all float8 weights. Example usage:

    model(input).sum().backward()
    optim.step()
    precompute_float8_dynamic_scale_for_fsdp(model)
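The cast-configuration and scaling-granularity entries above concern how a tensor is scaled before it is cast to float8. The pure-Python sketch below needs neither torch nor torchao. It illustrates dynamic scaling, where the scale is computed from the current absolute maximum so that it maps onto the largest finite float8 e4m3fn value (448). It then compares one scale per tensor with one scale per row. All helper names are hypothetical. The e4m3 rounding is a simplified software model, not torchao's implementation.

```python
import math

E4M3_MAX = 448.0  # largest finite float8 e4m3fn value
EPS = 1e-12       # avoids dividing by zero for all-zero inputs


def dynamic_scale(values):
    """Scale that maps the current max |value| onto E4M3_MAX."""
    amax = max(abs(v) for v in values)
    return E4M3_MAX / max(amax, EPS)


def tensorwise_scale(matrix):
    """One scale shared by every element of the matrix."""
    return dynamic_scale([v for row in matrix for v in row])


def rowwise_scales(matrix):
    """One scale per row, i.e. scaling along a single axis."""
    return [dynamic_scale(row) for row in matrix]


def round_to_e4m3(v):
    """Simplified model: round to the nearest e4m3fn value, saturating at +/-448."""
    if v == 0.0:
        return 0.0
    sign = -1.0 if v < 0 else 1.0
    a = min(abs(v), E4M3_MAX)
    m, e = math.frexp(a)  # a == m * 2**e with 0.5 <= m < 1
    if e <= -6:
        # Subnormals (a < 2**-6) are spaced 2**-9 apart.
        q = round(a / 2**-9) * 2**-9
    else:
        # Normals keep 3 mantissa bits, so m is rounded to a multiple of 1/16.
        q = round(m * 16) / 16 * 2**e
    return sign * min(q, E4M3_MAX)


def fake_cast(row, scale):
    """Scale, round to e4m3, then undo the scale to see the precision that survives."""
    return [round_to_e4m3(v * scale) / scale for v in row]


matrix = [[1000.0, -2.0], [0.01, 0.003]]  # rows with very different magnitudes
t_scale = tensorwise_scale(matrix)
r_scales = rowwise_scales(matrix)

# With one shared scale, the small row lands in e4m3's coarse subnormal range.
print("tensorwise:", fake_cast(matrix[1], t_scale))
# A per-row scale stretches the small row across the full e4m3 range.
print("rowwise:   ", fake_cast(matrix[1], r_scales[1]))
```

A finer granularity costs more scales to store and apply, but it keeps a single large outlier from crushing the precision of every other row.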