precompute_float8_dynamic_scale_for_fsdp

torchao.float8.precompute_float8_dynamic_scale_for_fsdp(module: Module) → None

Calculate scale dynamically for all float8 parameters.

This should be run after the optimizer step. It performs a single all-reduce to compute the scales for all float8 weights.
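The idea behind "a single all-reduce" can be illustrated with a conceptual simulation. This is not torchao's implementation: plain Python lists stand in for the weight shards held by each data-parallel rank, the collective is simulated, and the helper names (`local_amaxes`, `all_reduce_max`, `compute_scales`) and the `EPS` floor are hypothetical. It assumes per-tensor dynamic scaling, where each weight's scale maps its global absolute maximum (amax) onto the largest finite float8 e4m3fn value (448.0). All weights' amaxes are packed into one vector so that one MAX reduction covers every weight, instead of one collective per weight.

```python
"""Conceptual simulation of batched dynamic float8 scale computation.

Hypothetical setup: plain Python lists stand in for weight shards held by
each data-parallel rank; no torch or torchao APIs are used.
"""

FLOAT8_E4M3_MAX = 448.0  # largest finite value of float8 e4m3fn
EPS = 1e-12  # assumed floor that avoids division by zero for all-zero weights


def local_amaxes(rank_shards):
    """Per-weight absolute max over this rank's local shard of each weight."""
    return [max(abs(v) for v in shard) for shard in rank_shards]


def all_reduce_max(per_rank_amaxes):
    """Simulated MAX all-reduce over one packed vector of amaxes.

    Packing every weight's amax into a single vector means one collective
    per step instead of one collective per weight.
    """
    return [max(values) for values in zip(*per_rank_amaxes)]


def compute_scales(global_amaxes):
    """Dynamic per-tensor scale: map each weight's amax onto the float8 max."""
    return [FLOAT8_E4M3_MAX / max(a, EPS) for a in global_amaxes]


# Two ranks, each holding a shard of the same two weights.
rank0 = [[0.5, -2.0], [0.1, 0.0]]
rank1 = [[1.0, 0.25], [-0.5, 0.2]]

amaxes = all_reduce_max([local_amaxes(rank0), local_amaxes(rank1)])
scales = compute_scales(amaxes)
print(amaxes)  # [2.0, 0.5]
print(scales)  # [224.0, 896.0]
```

Because the optimizer step changes the weights, their amaxes (and therefore their scales) are stale afterwards, which is why the recomputation happens after `optim.step()`.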

Parameters

module – The module containing float8 parameters.

Example:

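from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp

model(input).sum().backward()
optim.step()
# Recompute the dynamic scales for all float8 weights now that they are updated
precompute_float8_dynamic_scale_for_fsdp(model)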