You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
49 lines
2.1 KiB
49 lines
2.1 KiB
""" Exponential Moving Average (EMA) of model updates
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
from copy import deepcopy
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class ModelEma(nn.Module):
    """ Model Exponential Moving Average

    Keep a moving average of everything in the model state_dict (parameters and buffers).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    A smoothed version of the weights is necessary for some training schemes to perform well.
    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
    smoothing of weights to match results. Pay attention to the decay constant you are using
    relative to your update count per epoch.

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.

    Floating-point state_dict entries are blended as ``decay * ema + (1 - decay) * model``.
    Non-floating-point entries (e.g. integer counters such as BatchNorm's
    ``num_batches_tracked``) cannot hold a fractional blend without truncation, so they
    are copied from the model verbatim instead.

    Args:
        model: Module whose weights are tracked. It is deep-copied; the copy is put in
            eval mode and stored as ``self.module``.
        decay: Factor applied to the running average on every ``update``; the model's
            values are weighted by ``1 - decay``.
        device: If not None, the EMA copy is moved to this device and incoming model
            tensors are moved there before blending.
    """

    def __init__(self, model, decay=0.9999, device=None):
        super(ModelEma, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.device = device  # perform ema on different device from model if set
        if device is not None:
            self.module.to(device=device)

    def update(self, model):
        """Fold the current state of ``model`` into the EMA copy, in place.

        Entries are paired by position in the two state_dicts (not by key), so ``model``
        must have the same structure as the module originally passed to ``__init__``.

        Args:
            model: Module whose current state_dict values are blended into ``self.module``.

        Raises:
            AssertionError: If two paired entries differ in shape.
        """
        decay = self.decay
        with torch.no_grad():
            ema_state = self.module.state_dict()
            for (key, ema_v), model_v in zip(ema_state.items(), model.state_dict().values()):
                assert ema_v.shape == model_v.shape, 'EMA/model shape mismatch for {}: {} vs {}'.format(
                    key, tuple(ema_v.shape), tuple(model_v.shape))
                # Same `is not None` test as __init__, so a falsy device such as 0
                # is honoured here too instead of silently skipping the move.
                if self.device is not None:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(ema_v * decay + (1. - decay) * model_v)
                else:
                    # Blending integer/bool buffers would truncate the result on copy_,
                    # which can leave counters stuck; track the model's value instead.
                    ema_v.copy_(model_v)
|