@@ -102,30 +102,34 @@ class ModelEmaV2(nn.Module):
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
     """
-    def __init__(self, model, decay=0.9999, device=None):
+    def __init__(self, model, decay=0.9999, device=None, foreach=False):
         super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
+        self.foreach = foreach
         self.device = device  # perform ema on different device from model if set
-        if self.device is not None:
+        if self.device is not None and device != next(model.parameters()).device:
+            self.foreach = False  # cannot use foreach methods with different devices
             self.module.to(device=device)
 
-    def _update(self, model, update_fn):
-        with torch.no_grad():
-            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                if self.device is not None:
-                    model_v = model_v.to(device=self.device)
-                update_fn(ema_v, model_v)
-
+    @torch.no_grad()
     def update(self, model):
-        def ema_update(e, m):
-            if m.is_floating_point():
-                e.mul_(self.decay).add_(m, alpha=1 - self.decay)
-        self._update(model, update_fn=ema_update)
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            torch._foreach_mul_(ema_params, scalar=self.decay)
+            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
+
+        # copy buffers instead of EMA
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))
 
+    @torch.no_grad()
     def set(self, model):
-        self._update(model, update_fn=lambda e, m: e.copy_(m))
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.copy_(model_v.to(device=self.device))
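
For reference, a minimal usage sketch of the new foreach path. It assumes the patched class is timm's ModelEmaV2 (importable as timm.utils.ModelEmaV2); the Linear model, SGD optimizer, and random batch below are placeholders for a real training setup and are not part of this change.

# minimal sketch: EMA tracking with the new foreach argument (assumes the patch above is applied)
import torch
import torch.nn as nn
from timm.utils import ModelEmaV2

model = nn.Linear(16, 4)                                   # stand-in for a real model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
ema = ModelEmaV2(model, decay=0.9999, foreach=True)        # new foreach argument from this hunk

for _ in range(10):
    x = torch.randn(8, 16)
    y = torch.randint(0, 4, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                                      # fused EMA step over all parameters

# the EMA weights live in ema.module; use that copy for evaluation or checkpointing
ema.module.eval()

Note the design choice in the constructor: when the EMA copy is kept on a different device than the model, foreach is forced back to False and the per-parameter loop is used, since the torch._foreach_* ops require all tensors in the list to share a device.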