From 4ca52d73d8fb1ebfd5d272576295f03f7e34fc15 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 3 Dec 2020 10:05:09 -0800 Subject: [PATCH] Add separate set and update method to ModelEmaV2 --- timm/utils/model_ema.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index a767eaa5..073d5c5e 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -112,9 +112,15 @@ class ModelEmaV2(nn.Module): if self.device is not None: self.module.to(device=device) - def update(self, model): + def _update(self, model, update_fn): with torch.no_grad(): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) + ema_v.copy_(update_fn(ema_v, model_v)) + + def update(self, model): + self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def set(self, model): + self._update(model, update_fn=lambda e, m: m)