From 9214ca071674ce62b0eff36f0a1e3eaaba6ec2e3 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 16 Nov 2020 12:51:52 -0800
Subject: [PATCH] Simplifying EMA...

---
 timm/utils/model.py     |  5 +----
 timm/utils/model_ema.py | 49 +++++++++----------------------------------
 train.py                |  2 +-
 3 files changed, 12 insertions(+), 44 deletions(-)

diff --git a/timm/utils/model.py b/timm/utils/model.py
index cfd42806..0d6700b7 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -6,10 +6,7 @@ from .model_ema import ModelEma
 
 
 def unwrap_model(model):
-    if isinstance(model, ModelEma):
-        return unwrap_model(model.ema)
-    else:
-        return model.module if hasattr(model, 'module') else model
+    return model.module if hasattr(model, 'module') else model
 
 
 def get_state_dict(model, unwrap_fn=unwrap_model):
diff --git a/timm/utils/model_ema.py b/timm/utils/model_ema.py
index b788b32e..f146e471 100644
--- a/timm/utils/model_ema.py
+++ b/timm/utils/model_ema.py
@@ -2,16 +2,13 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
-import logging
-from collections import OrderedDict
 from copy import deepcopy
 
 import torch
+import torch.nn as nn
 
-_logger = logging.getLogger(__name__)
-
 
-class ModelEma:
+class ModelEma(nn.Module):
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).
@@ -32,46 +29,20 @@
     GPU assignment and distributed training wrappers.
     I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
     """
-    def __init__(self, model, decay=0.9999, device='', resume=''):
+    def __init__(self, model, decay=0.9999, device=None):
+        super(ModelEma, self).__init__()
         # make a copy of the model for accumulating moving average of weights
-        self.ema = deepcopy(model)
-        self.ema.eval()
+        self.module = deepcopy(model)
+        self.module.eval()
         self.decay = decay
         self.device = device  # perform ema on different device from model if set
-        if device:
-            self.ema.to(device=device)
-        self.ema_has_module = hasattr(self.ema, 'module')
-        if resume:
-            self._load_checkpoint(resume)
-        for p in self.ema.parameters():
-            p.requires_grad_(False)
-
-    def _load_checkpoint(self, checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
-        assert isinstance(checkpoint, dict)
-        if 'state_dict_ema' in checkpoint:
-            new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict_ema'].items():
-                # ema model may have been wrapped by DataParallel, and need module prefix
-                if self.ema_has_module:
-                    name = 'module.' + k if not k.startswith('module') else k
-                else:
-                    name = k
-                new_state_dict[name] = v
-            self.ema.load_state_dict(new_state_dict)
-            _logger.info("Loaded state_dict_ema")
-        else:
-            _logger.warning("Failed to find state_dict_ema, starting from loaded model weights")
+        if device is not None:
+            self.module.to(device=device)
 
     def update(self, model):
-        # correct a mismatch in state dict keys
-        needs_module = hasattr(model, 'module') and not self.ema_has_module
        with torch.no_grad():
-            msd = model.state_dict()
-            for k, ema_v in self.ema.state_dict().items():
-                if needs_module:
-                    k = 'module.' + k
-                model_v = msd[k].detach()
+            for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
+                assert ema_v.shape == model_v.shape
                 if self.device:
                     model_v = model_v.to(device=self.device)
                 ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
diff --git a/train.py b/train.py
index ef3adf85..f56089e3 100755
--- a/train.py
+++ b/train.py
@@ -568,7 +568,7 @@ def main():
                 if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
                     distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
                 ema_eval_metrics = validate(
-                    model_ema.ema, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
+                    model_ema.module, loader_eval, validate_loss_fn, args, amp_autocast=amp_autocast, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
 
             if lr_scheduler is not None:
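
For context, a minimal usage sketch (not part of the patch) of how the simplified, nn.Module-based ModelEma is typically driven after this change: build it from the live model, call update() after each optimizer step, and read the averaged weights from .module (formerly .ema), matching the train.py change above. The toy model, optimizer, and data below are placeholders for illustration only, and the import path assumes a timm checkout with this patch applied.

# Illustrative sketch only; assumes this patch is applied to timm.
import torch
import torch.nn as nn
from timm.utils.model_ema import ModelEma

model = nn.Linear(10, 2)                       # stand-in for a real timm model
ema = ModelEma(model, decay=0.9999)            # deep-copied model, held in eval mode
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for _ in range(3):                             # a few toy training steps
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    ema.update(model)                          # ema_v = decay * ema_v + (1 - decay) * model_v

# After this patch the averaged copy lives on .module (it was .ema before),
# which is why train.py now validates model_ema.module.
with torch.no_grad():
    preds = ema.module(torch.randn(4, 10))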