Add foreach option for faster EMA

pull/1553/head
Jerome Rony 2 years ago
parent 6ec5cd6a99
commit 3491506fec

@@ -102,30 +102,34 @@ class ModelEmaV2(nn.Module):
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
     """
-    def __init__(self, model, decay=0.9999, device=None):
+    def __init__(self, model, decay=0.9999, device=None, foreach=False):
         super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
+        self.foreach = foreach
         self.device = device  # perform ema on different device from model if set
-        if self.device is not None:
+        if self.device is not None and device != next(model.parameters()).device:
+            self.foreach = False  # cannot use foreach methods with different devices
             self.module.to(device=device)

-    def _update(self, model, update_fn):
-        with torch.no_grad():
-            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                if self.device is not None:
-                    model_v = model_v.to(device=self.device)
-                update_fn(ema_v, model_v)
-
+    @torch.no_grad()
     def update(self, model):
-        def ema_update(e, m):
-            if m.is_floating_point():
-                e.mul_(self.decay).add_(m, alpha=1 - self.decay)
-        self._update(model, update_fn=ema_update)
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            torch._foreach_mul_(ema_params, scalar=self.decay)
+            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
+        # copy buffers instead of EMA
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))

+    @torch.no_grad()
     def set(self, model):
-        self._update(model, update_fn=lambda e, m: e.copy_(m))
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.copy_(model_v.to(device=self.device))
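
For context, the torch._foreach_* ops used on the fast path apply one fused in-place operation across a whole list of tensors, avoiding per-tensor Python-loop and kernel-launch overhead. A minimal standalone sketch, with hypothetical tensor lists standing in for the EMA and model parameter tuples, showing that the foreach path computes the same update as the per-parameter fallback:

import torch

decay = 0.9999
ema_params = [torch.zeros(3), torch.zeros(2, 2)]    # stand-in for ema.module.parameters()
model_params = [torch.ones(3), torch.ones(2, 2)]    # stand-in for model.parameters()

# foreach path: one fused call per op over all tensors
torch._foreach_mul_(ema_params, scalar=decay)
torch._foreach_add_(ema_params, model_params, alpha=1 - decay)

# per-tensor fallback, as in the else branch of update()
ema_ref = [torch.zeros(3), torch.zeros(2, 2)]
for e, m in zip(ema_ref, model_params):
    e.mul_(decay).add_(m, alpha=1 - decay)

assert all(torch.allclose(a, b) for a, b in zip(ema_params, ema_ref))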

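A sketch of how the new option might be used in a training loop. The model, data, and hyperparameters are placeholders, and the import assumes ModelEmaV2 is re-exported from timm.utils; with device=None the EMA copy stays on the model's device, so the foreach fast path remains enabled:

import torch
import torch.nn as nn
from timm.utils import ModelEmaV2  # assumed import path

model = nn.Linear(10, 2)
ema = ModelEmaV2(model, decay=0.999, foreach=True)  # EMA on same device as model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(100):  # hypothetical training loop
    x = torch.randn(8, 10)
    loss = model(x).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(model)  # fused _foreach_mul_/_foreach_add_ over all parameters

# evaluate with the smoothed weights
ema.module.eval()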