get_gradient_scaler¶
- torchtune.utils.precision.get_gradient_scaler(fsdp: bool = False) → Union[GradScaler, ShardedGradScaler] [source]¶
Returns a gradient scaler for mixed-precision training.
- Parameters:
fsdp (bool) – Whether FSDP is being used for training, in which case a shard-aware gradient scaler is returned.
- Returns:
A gradient scaler: ShardedGradScaler if fsdp is True, otherwise GradScaler.
- Return type:
Union[GradScaler, ShardedGradScaler]
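The returned scaler follows PyTorch's standard AMP workflow: scale the loss before backward, then step and update through the scaler. Below is a minimal sketch of that loop using torch.cuda.amp.GradScaler directly (the non-FSDP case); the model, optimizer, and enabled=False flag are illustrative choices to keep the example runnable on CPU-only machines, not part of the torchtune API.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

# Toy model and optimizer for illustration only.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# enabled=False keeps this runnable without a GPU; with CUDA you would
# pass enabled=True (or use the ShardedGradScaler variant under FSDP).
scaler = GradScaler(enabled=False)

x = torch.randn(8, 4)
target = torch.randn(8, 2)

optimizer.zero_grad()
with autocast(enabled=False):  # autocast likewise disabled on CPU
    loss = torch.nn.functional.mse_loss(model(x), target)

scaler.scale(loss).backward()  # scale the loss, backprop scaled gradients
scaler.step(optimizer)         # unscale gradients, then optimizer step
scaler.update()                # adjust the scale factor for the next step
```

When fsdp=True, the same loop applies unchanged; the shard-aware ShardedGradScaler handles gradients that are sharded across ranks.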