torch.optim.swa_utils.get_ema_avg_fn

torch.optim.swa_utils.get_ema_avg_fn(decay=0.999)

Get the function applying exponential moving average (EMA) across multiple params.

The EMA is computed as:

$$W_0^{\text{EMA}} = W_0^{\text{model}}$$

$$W_{t+1}^{\text{EMA}} = \text{decay} \times W_t^{\text{EMA}} + (1 - \text{decay}) \times W_{t+1}^{\text{model}}$$

where $W_t^{\text{EMA}}$ is the EMA parameter at step $t$, $W_t^{\text{model}}$ is the model parameter at step $t$, and $\text{decay}$ is the decay rate (default: 0.999).
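The recurrence can be traced by hand on scalar values. The following is a minimal pure-Python sketch for illustration only, not the library's implementation; the helper name `ema_sequence` and the sample values are hypothetical.

```python
def ema_sequence(values, decay=0.999):
    """Return the EMA after each step for a sequence of scalar parameter values.

    Hypothetical helper that mirrors the two equations of the EMA definition.
    """
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must be in the range [0, 1]")
    ema = values[0]  # initialization: the EMA starts at the first model value
    history = [ema]
    for w in values[1:]:
        # update: decay * previous EMA + (1 - decay) * new model value
        ema = decay * ema + (1 - decay) * w
        history.append(ema)
    return history


# With decay=0.5 each step moves the EMA halfway toward the new value.
trace = ema_sequence([1.0, 2.0, 3.0], decay=0.5)
print(trace)  # [1.0, 1.5, 2.25]
```

At the extremes, `decay=1.0` keeps the EMA fixed at its initial value and `decay=0.0` makes it equal to the latest model value.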

Parameters:

decay (float) – Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999

Returns:

A function that updates EMA parameters given current model parameters

Return type:

Callable