Use in-place operations for EMA

pull/1552/head
Jerome Rony 3 years ago
parent 25ffac6880
commit d7165588c1

@ -117,10 +117,11 @@ class ModelEmaV2(nn.Module):
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
if self.device is not None:
model_v = model_v.to(device=self.device)
ema_v.copy_(update_fn(ema_v, model_v))
if model_v.is_floating_point():
update_fn(ema_v, model_v)
def update(self, model):
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
self._update(model, update_fn=lambda e, m: e.mul_(self.decay).add_(m, alpha=1 - self.decay))
def set(self, model):
self._update(model, update_fn=lambda e, m: m)

Loading…
Cancel
Save