diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py index 073d5c5e..6c481449 100644 --- a/timm/utils/model_ema.py +++ b/timm/utils/model_ema.py @@ -117,10 +117,11 @@ class ModelEmaV2(nn.Module): for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()): if self.device is not None: model_v = model_v.to(device=self.device) - ema_v.copy_(update_fn(ema_v, model_v)) + if model_v.is_floating_point(): + update_fn(ema_v, model_v) def update(self, model): - self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m) + self._update(model, update_fn=lambda e, m: e.mul_(self.decay).add_(m, alpha=1 - self.decay)) def set(self, model): self._update(model, update_fn=lambda e, m: m)