|
|
@ -112,9 +112,15 @@ class ModelEmaV2(nn.Module):
|
|
|
|
if self.device is not None:
|
|
|
|
if self.device is not None:
|
|
|
|
self.module.to(device=device)
|
|
|
|
self.module.to(device=device)
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, model):
|
|
|
|
def _update(self, model, update_fn):
|
|
|
|
with torch.no_grad():
|
|
|
|
with torch.no_grad():
|
|
|
|
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
|
|
|
for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
|
|
|
|
if self.device is not None:
|
|
|
|
if self.device is not None:
|
|
|
|
model_v = model_v.to(device=self.device)
|
|
|
|
model_v = model_v.to(device=self.device)
|
|
|
|
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
|
|
|
|
ema_v.copy_(update_fn(ema_v, model_v))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update(self, model):
|
|
|
|
|
|
|
|
self._update(model, update_fn=lambda e, m: self.decay * e + (1. - self.decay) * m)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def set(self, model):
|
|
|
|
|
|
|
|
self._update(model, update_fn=lambda e, m: m)
|
|
|
|