""" Exponential Moving Average (EMA) of model updates

Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from collections import OrderedDict
from copy import deepcopy

import torch
import torch.nn as nn

_logger = logging.getLogger(__name__)

_MODULE_PREFIX = 'module.'


class ModelEma:
    """ Model Exponential Moving Average (DEPRECATED)

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This version is deprecated, it does not work with scripted models. Will be removed eventually.

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.

    Floating point state entries are averaged; non floating point entries (e.g. integer
    buffers such as BatchNorm ``num_batches_tracked``) are copied from the model as-is.
    """

    def __init__(self, model, decay=0.9999, device='', resume=''):
        """
        Args:
            model: module whose weights are tracked; it is deep-copied into ``self.ema``.
            decay: EMA decay factor; the EMA keeps ``decay`` of its old value per update.
            device: if truthy, the EMA copy is moved to this device.
            resume: optional checkpoint path containing a ``state_dict_ema`` entry.
        """
        # make a copy of the model for accumulating moving average of weights
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device:
            self.ema.to(device=device)
        self.ema_has_module = hasattr(self.ema, 'module')
        if resume:
            self._load_checkpoint(resume)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _load_checkpoint(self, checkpoint_path):
        """Load ``checkpoint['state_dict_ema']`` into the EMA model, if present.

        The ``module.`` prefix (added by wrappers such as DataParallel) is added or
        stripped from the checkpoint keys so they match the EMA model's own wrapping.
        """
        # NOTE: depending on the torch version, torch.load may unpickle arbitrary
        # objects; only resume from trusted checkpoint files.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        assert isinstance(checkpoint, dict)
        if 'state_dict_ema' in checkpoint:
            new_state_dict = OrderedDict()
            for k, v in checkpoint['state_dict_ema'].items():
                has_prefix = k.startswith(_MODULE_PREFIX)
                if self.ema_has_module:
                    # ema model is wrapped (e.g. DataParallel) and needs the module prefix
                    name = k if has_prefix else _MODULE_PREFIX + k
                else:
                    # ema model is unwrapped; drop a prefix saved from a wrapped model
                    name = k[len(_MODULE_PREFIX):] if has_prefix else k
                new_state_dict[name] = v
            self.ema.load_state_dict(new_state_dict)
            _logger.info("Loaded state_dict_ema")
        else:
            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")

    def update(self, model):
        """Blend the current state of ``model`` into the EMA copy.

        Floating point entries become ``decay * ema + (1 - decay) * model``. Other
        entries are copied directly: blending an integer tensor would yield a float
        that ``copy_`` truncates, leaving counters stuck near their initial value.
        """
        model_has_module = hasattr(model, 'module')
        # correct a mismatch in state dict keys when exactly one side is wrapped
        add_prefix = model_has_module and not self.ema_has_module
        strip_prefix = self.ema_has_module and not model_has_module
        with torch.no_grad():
            msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                if add_prefix:
                    k = _MODULE_PREFIX + k
                elif strip_prefix and k.startswith(_MODULE_PREFIX):
                    k = k[len(_MODULE_PREFIX):]
                model_v = msd[k].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
                else:
                    ema_v.copy_(model_v)


class ModelEmaV2(nn.Module):
    """ Model Exponential Moving Average V2

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V2 of this module is simpler, it does not match params/buffers based on name but simply
    iterates in order. It works with torchscript (JIT of full model).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.

    Floating point state entries are averaged; non floating point entries (e.g. integer
    buffers such as BatchNorm ``num_batches_tracked``) are copied from the model as-is.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEmaV2, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if self.device is not None:
            self.module.to(device=device)

    def _update(self, model, update_fn):
        """Apply ``update_fn(ema_value, model_value)`` to each state entry, pairing by order."""
        with torch.no_grad():
            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
        """Blend ``model`` into the EMA: float entries are averaged, others are copied.

        Blending an integer tensor would yield a float that ``copy_`` truncates, so
        non floating point entries take the model's value directly.
        """
        def _ema_fn(e, m):
            if e.is_floating_point():
                return self.decay * e + (1. - self.decay) * m
            return m

        self._update(model, update_fn=_ema_fn)

    def set(self, model):
        """Overwrite the EMA state with the current state of ``model``."""
        self._update(model, update_fn=lambda e, m: m)