From 3491506fecf5adc85e0e4da724c6c2d43d3f9904 Mon Sep 17 00:00:00 2001
From: Jerome Rony
Date: Wed, 30 Nov 2022 14:06:58 -0500
Subject: [PATCH] Add foreach option for faster EMA

---
 timm/utils/model_ema.py | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)

diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py
index 0213d582..5cefe08b 100644
--- a/timm/utils/model_ema.py
+++ b/timm/utils/model_ema.py
@@ -102,30 +102,34 @@ class ModelEmaV2(nn.Module):
     This class is sensitive where it is initialized in the sequence of model init,
     GPU assignment and distributed training wrappers.
     """
-    def __init__(self, model, decay=0.9999, device=None):
+    def __init__(self, model, decay=0.9999, device=None, foreach=False):
         super(ModelEmaV2, self).__init__()
         # make a copy of the model for accumulating moving average of weights
         self.module = deepcopy(model)
         self.module.eval()
         self.decay = decay
+        self.foreach = foreach
         self.device = device  # perform ema on different device from model if set
-        if self.device is not None:
+        if self.device is not None and device != next(model.parameters()).device:
+            self.foreach = False  # cannot use foreach methods with different devices
             self.module.to(device=device)
 
-    def _update(self, model, update_fn):
-        with torch.no_grad():
-            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
-                if self.device is not None:
-                    model_v = model_v.to(device=self.device)
-                update_fn(ema_v, model_v)
-
+    @torch.no_grad()
     def update(self, model):
+        ema_params = tuple(self.module.parameters())
+        model_params = tuple(model.parameters())
+        if self.foreach:
+            torch._foreach_mul_(ema_params, scalar=self.decay)
+            torch._foreach_add_(ema_params, model_params, alpha=1 - self.decay)
+        else:
+            for ema_p, model_p in zip(ema_params, model_params):
+                ema_p.mul_(self.decay).add_(model_p.to(device=self.device), alpha=1 - self.decay)
 
-        def ema_update(e, m):
-            if m.is_floating_point():
-                e.mul_(self.decay).add_(m, alpha=1 - self.decay)
-
-        self._update(model, update_fn=ema_update)
+        # copy buffers instead of EMA
+        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
+            ema_b.copy_(model_b.to(device=self.device))
 
+    @torch.no_grad()
     def set(self, model):
-        self._update(model, update_fn=lambda e, m: e.copy_(m))
+        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+            ema_v.copy_(model_v.to(device=self.device))
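
A minimal usage sketch of the new option, assuming the patch above is applied to a timm checkout; the resnet18 model, input shape, and training loop are placeholders for illustration only, and the observed speedup will depend on parameter count and device.

# Usage sketch (assumes the patched ModelEmaV2 with the foreach argument).
import torch
import timm
from timm.utils.model_ema import ModelEmaV2

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = timm.create_model('resnet18', pretrained=False).to(device)

# foreach=True replaces the per-parameter Python loop with two fused
# multi-tensor ops (torch._foreach_mul_ / torch._foreach_add_).
ema = ModelEmaV2(model, decay=0.9999, device=None, foreach=True)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(3):  # placeholder training loop
    x = torch.randn(2, 3, 224, 224, device=device)
    loss = model(x).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)  # EMA weights track the live model after each step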