From 6ec5cd6a99911f2ae771536fc9dd382243069105 Mon Sep 17 00:00:00 2001 From: Jerome Rony Date: Thu, 17 Nov 2022 11:53:29 -0500 Subject: [PATCH] Use in-place operations for EMA --- timm/utils/model_ema.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 073d5c5e..0213d582 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -117,10 +117,15 @@ class ModelEmaV2(nn.Module): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(update_fn(ema_v, model_v)) + update_fn(ema_v, model_v) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + + def ema_update(e, m): + if m.is_floating_point(): + e.mul_(self.decay).add_(m, alpha=1 - self.decay) + + self._update(model, update_fn=ema_update) def set(self, model): - self._update(model, update_fn=lambda e, m: m) + self._update(model, update_fn=lambda e, m: e.copy_(m))