diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index a767eaa5..073d5c5e 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -112,9 +112,15 @@ class ModelEmaV2(nn.Module): if self.device is not None: self.module.to(device=device) - def update(self, model): + def _update(self, model, update_fn): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m)