From d7165588c16d7447bd87ae358170536f9d582be4 Mon Sep 17 00:00:00 2001 From: Jerome Rony Date: Thu, 17 Nov 2022 11:15:37 -0500 Subject: [PATCH] Use in-place operations for EMA --- timm/utils/model_ema.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 073d5c5e..6c481449 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -117,10 +117,11 @@ class ModelEmaV2(nn.Module): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(update_fn(ema_v, model_v)) + if model_v.is_floating_point(): + update_fn(ema_v, model_v) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update(model, update_fn=lambda e, m: e.mul_(self.decay).add_(m, alpha=1 - self.decay)) def set(self, model): self._update(model, update_fn=lambda e, m: m)