torch.optim.swa_utils.get_ema_avg_fn
- torch.optim.swa_utils.get_ema_avg_fn(decay=0.999)
Get the function that applies an exponential moving average (EMA) to a single parameter.
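The EMA is computed as:

$$\theta_{\text{ema}}^{t+1} = \text{decay} \cdot \theta_{\text{ema}}^{t} + (1 - \text{decay}) \cdot \theta^{t}$$

where $\theta_{\text{ema}}^{t}$ is the EMA parameter at step $t$, $\theta^{t}$ is the model parameter at step $t$, and $\text{decay}$ is the decay rate (default: 0.999).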
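To make the recurrence concrete, here is a plain-Python sketch of the same update applied to scalar values. `make_ema_update` is a hypothetical helper written for illustration only; it is not part of PyTorch and only mirrors the formula above, including the [0, 1] range check on `decay`.

```python
def make_ema_update(decay: float = 0.999):
    """Hypothetical helper (not PyTorch API) mirroring the EMA formula."""
    if not 0.0 <= decay <= 1.0:
        raise ValueError(f"decay must be in [0, 1], got {decay}")

    def ema_update(ema_param: float, current_param: float) -> float:
        # theta_ema^{t+1} = decay * theta_ema^t + (1 - decay) * theta^t
        return decay * ema_param + (1.0 - decay) * current_param

    return ema_update


update = make_ema_update(decay=0.9)
ema = 0.0
# Feed a constant parameter value of 1.0 for three steps.
for theta in [1.0, 1.0, 1.0]:
    ema = update(ema, theta)

print(round(ema, 6))  # 0.271, i.e. 1 - 0.9**3
```

Starting from 0, the EMA closes the gap to a constant target by a factor of `decay` each step, which is why a decay close to 1 yields a slowly moving, smoothed average.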
- Parameters:
decay (float) – Decay rate for EMA. Must be in the range [0, 1]. Default: 0.999
- Returns:
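A function `ema_update(ema_param, current_param, num_averaged)` that returns the updated EMA value for one parameter tensor, given the current EMA tensor and the corresponding current model tensor. The `num_averaged` argument is accepted for compatibility with `AveragedModel` but is not used.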
- Return type:
Callable