Add exponential moving average for model weights + few other additions and cleanup

* ModelEma class added to track an EMA set of weights for the model being trained
* EMA handling added to train, validation and clean_checkpoint scripts
* Add multi checkpoint or multi-model validation support to validate.py
* Add syncbn option (APEX) to train script for experimentation
* Cleanup interface of CheckpointSaver while adding ema functionality
pull/12/head
Ross Wightman 6 years ago
parent ff99625603
commit 9bcd65181b

@ -9,6 +9,8 @@ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
help='output path')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
def main():
@ -24,8 +26,13 @@ def main():
checkpoint = torch.load(args.checkpoint, map_location='cpu')
new_state_dict = OrderedDict()
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
if isinstance(checkpoint, dict):
state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
if state_dict_key in checkpoint:
state_dict = checkpoint[state_dict_key]
else:
print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
exit(1)
else:
state_dict = checkpoint
for k, v in state_dict.items():

@ -4,22 +4,24 @@ import os
from collections import OrderedDict
def load_checkpoint(model, checkpoint_path):
def load_checkpoint(model, checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path):
print("=> Loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
state_dict_key = ''
if isinstance(checkpoint, dict):
state_dict_key = 'state_dict'
if use_ema and 'state_dict_ema' in checkpoint:
state_dict_key = 'state_dict_ema'
if state_dict_key and state_dict_key in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
for k, v in checkpoint[state_dict_key].items():
# strip `module.` prefix
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
else:
model.load_state_dict(checkpoint)
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
else:
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()
@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
def resume_checkpoint(model, checkpoint_path, start_epoch=None):
optimizer_state = None
if os.path.isfile(checkpoint_path):
print("=> loading checkpoint '{}'".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path)
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items():
if k.startswith('module'):
name = k[7:] # remove `module.`
else:
name = k
name = k[7:] if k.startswith('module') else k
new_state_dict[name] = v
model.load_state_dict(new_state_dict)
if 'optimizer' in checkpoint:
optimizer_state = checkpoint['optimizer']
print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
else:
model.load_state_dict(checkpoint)
start_epoch = 0 if start_epoch is None else start_epoch
print("=> Loaded checkpoint '{}'".format(checkpoint_path))
return optimizer_state, start_epoch
else:
print("=> No checkpoint found at '{}'".format(checkpoint_path))
print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
raise FileNotFoundError()

@ -89,7 +89,7 @@ class RMSpropTF(Optimizer):
state['step'] += 1
if group['weight_decay'] != 0:
if group['decoupled_decay']:
if 'decoupled_decay' in group and group['decoupled_decay']:
p.data.add_(-group['weight_decay'], p.data)
else:
grad = grad.add(group['weight_decay'], p.data)
@ -109,7 +109,7 @@ class RMSpropTF(Optimizer):
if group['momentum'] > 0:
buf = state['momentum_buffer']
# Tensorflow accumulates the LR scaling in the momentum buffer
if group['lr_in_momentum']:
if 'lr_in_momentum' in group and group['lr_in_momentum']:
buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
p.data.add_(-buf)
else:

@ -6,12 +6,13 @@ from datetime import datetime
try:
from apex import amp
from apex.parallel import DistributedDataParallel as DDP
from apex.parallel import convert_syncbn_model
has_apex = True
except ImportError:
has_apex = False
from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
from models import create_model, resume_checkpoint
from models import create_model, resume_checkpoint, load_checkpoint
from utils import *
from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from optim import create_optimizer
@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
parser.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
parser.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
parser.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
parser.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)')
@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training')
parser.add_argument('--sync-bn', action='store_true',
help='enabling apex sync BN.')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
parser.add_argument('--output', default='', type=str, metavar='PATH',
@ -131,31 +140,28 @@ def main():
args.device = 'cuda:0'
args.world_size = 1
r = -1
args.rank = 0 # global rank
if args.distributed:
args.num_gpu = 1
args.device = 'cuda:%d' % args.local_rank
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend='nccl',
init_method='env://')
torch.distributed.init_process_group(
backend='nccl', init_method='env://')
args.world_size = torch.distributed.get_world_size()
r = torch.distributed.get_rank()
args.rank = torch.distributed.get_rank()
assert args.rank >= 0
if args.distributed:
print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
% (r, args.world_size))
% (args.rank, args.world_size))
else:
print('Training with a single process on %d GPUs.' % args.num_gpu)
# FIXME seed handling for multi-process distributed?
torch.manual_seed(args.seed)
torch.manual_seed(args.seed + args.rank)
output_dir = ''
if args.local_rank == 0:
if args.output:
output_base = args.output
else:
output_base = './output'
output_base = args.output if args.output else './output'
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
args.model,
@ -191,6 +197,8 @@ def main():
args.amp = False
model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
else:
if args.distributed and args.sync_bn and has_apex:
model = convert_syncbn_model(model)
model.cuda()
optimizer = create_optimizer(args, model)
@ -205,8 +213,20 @@ def main():
use_amp = False
print('AMP disabled')
model_ema = None
if args.model_ema:
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume=args.resume)
if args.distributed:
model = DDP(model, delay_allreduce=True)
if model_ema is not None and not args.model_ema_force_cpu:
# must also distribute EMA model to allow validation
model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
model_ema.ema_has_module = True
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
if start_epoch > 0:
@ -273,6 +293,7 @@ def main():
eval_metric = args.eval_metric
saver = None
if output_dir:
# only set if process is rank 0
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
best_metric = None
@ -284,10 +305,15 @@ def main():
train_metrics = train_epoch(
epoch, model, loader_train, optimizer, train_loss_fn, args,
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
use_amp=use_amp, model_ema=model_ema)
eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
eval_metrics = validate(
model, loader_eval, validate_loss_fn, args)
if model_ema is not None and not args.model_ema_force_cpu:
ema_eval_metrics = validate(
model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
eval_metrics = ema_eval_metrics
if lr_scheduler is not None:
lr_scheduler.step(epoch, eval_metrics[eval_metric])
@ -298,15 +324,12 @@ def main():
if saver is not None:
# save proper checkpoint with eval metric
best_metric, best_epoch = saver.save_checkpoint({
'epoch': epoch + 1,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
save_metric = eval_metrics[eval_metric]
best_metric, best_epoch = saver.save_checkpoint(
model, optimizer, args,
epoch=epoch + 1,
metric=eval_metrics[eval_metric])
model_ema=model_ema,
metric=save_metric)
except KeyboardInterrupt:
pass
@ -316,7 +339,7 @@ def main():
def train_epoch(
epoch, model, loader, optimizer, loss_fn, args,
lr_scheduler=None, saver=None, output_dir='', use_amp=False):
lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@ -359,6 +382,8 @@ def train_epoch(
optimizer.step()
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
num_updates += 1
batch_time_m.update(time.time() - end)
@ -394,18 +419,11 @@ def train_epoch(
padding=0,
normalize=True)
if args.local_rank == 0 and (
saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
if saver is not None and args.recovery_interval and (
last_batch or (batch_idx + 1) % args.recovery_interval == 0):
save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery({
'epoch': save_epoch,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
},
epoch=save_epoch,
batch_idx=batch_idx)
saver.save_recovery(
model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
if lr_scheduler is not None:
lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@ -415,7 +433,7 @@ def train_epoch(
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args):
def validate(model, loader, loss_fn, args, log_suffix=''):
batch_time_m = AverageMeter()
losses_m = AverageMeter()
prec1_m = AverageMeter()
@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args):
batch_time_m.update(time.time() - end)
end = time.time()
if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
print('Test: [{0}/{1}]\t'
log_name = 'Test' + log_suffix
print('{0}: [{1}/{2}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
batch_idx, last_idx,
log_name, batch_idx, last_idx,
batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))
@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args):
return metrics
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
if __name__ == '__main__':
main()

@ -1,6 +1,9 @@
from copy import deepcopy
import torch
import math
import os
import re
import shutil
import glob
import csv
@ -8,6 +11,15 @@ import operator
import numpy as np
from collections import OrderedDict
from torch import distributed as dist
def get_state_dict(model):
if isinstance(model, ModelEma):
return get_state_dict(model.ema)
else:
return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
class CheckpointSaver:
def __init__(
@ -39,17 +51,16 @@ class CheckpointSaver:
self.max_history = max_history
assert self.max_history >= 1
def save_checkpoint(self, state, epoch, metric=None):
def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
assert epoch >= 0
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]):
if (len(self.checkpoint_files) < self.max_history
or metric is None or self.cmp(metric, worst_file[1])):
if len(self.checkpoint_files) >= self.max_history:
self._cleanup_checkpoints(1)
filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
save_path = os.path.join(self.checkpoint_dir, filename)
if metric is not None:
state['metric'] = metric
torch.save(state, save_path)
self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
self.checkpoint_files.append((save_path, metric))
self.checkpoint_files = sorted(
self.checkpoint_files, key=lambda x: x[1],
@ -67,6 +78,20 @@ class CheckpointSaver:
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
save_state = {
'epoch': epoch,
'arch': args.model,
'state_dict': get_state_dict(model),
'optimizer': optimizer.state_dict(),
'args': args
}
if model_ema is not None:
save_state['state_dict_ema'] = get_state_dict(model_ema)
if metric is not None:
save_state['metric'] = metric
torch.save(save_state, save_path)
def _cleanup_checkpoints(self, trim=0):
trim = min(len(self.checkpoint_files), trim)
delete_index = self.max_history - trim
@ -82,10 +107,11 @@ class CheckpointSaver:
print('Exception (%s) while deleting checkpoint' % str(e))
self.checkpoint_files = self.checkpoint_files[:delete_index]
def save_recovery(self, state, epoch, batch_idx):
def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
assert epoch >= 0
filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
save_path = os.path.join(self.recovery_dir, filename)
torch.save(state, save_path)
self._save(save_path, model, optimizer, args, epoch, model_ema)
if os.path.exists(self.last_recovery_file):
try:
if self.verbose:
@ -165,3 +191,81 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
if write_header: # first iteration (epoch == 1 can't be used)
dw.writeheader()
dw.writerow(rowd)
def natural_key(string_):
"""See http://www.codinghorror.com/blog/archives/001018.html"""
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def reduce_tensor(tensor, n):
rt = tensor.clone()
dist.all_reduce(rt, op=dist.ReduceOp.SUM)
rt /= n
return rt
class ModelEma:
""" Model Exponential Moving Average
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
smoothing of weights to match results. Pay attention to the decay constant you are using
relative to your update count per epoch.
To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
disable validation of the EMA weights. Validation will have to be done manually in a separate
process, or after the training stops converging.
This class is sensitive where it is initialized in the sequence of model init,
GPU assignment and distributed training wrappers.
I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
"""
def __init__(self, model, decay=0.9999, device='', resume=''):
# make a copy of the model for accumulating moving average of weights
self.ema = deepcopy(model)
self.ema.eval()
self.decay = decay
self.device = device # perform ema on different device from model if set
if device:
self.ema.to(device=device)
self.ema_has_module = hasattr(self.ema, 'module')
if resume:
self._load_checkpoint(resume)
for p in self.ema.parameters():
p.requires_grad_(False)
def _load_checkpoint(self, checkpoint_path):
checkpoint = torch.load(checkpoint_path)
assert isinstance(checkpoint, dict)
if 'state_dict_ema' in checkpoint:
new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict_ema'].items():
# ema model may have been wrapped by DataParallel, and need module prefix
if self.ema_has_module:
name = 'module.' + k if not k.startswith('module') else k
else:
name = k
new_state_dict[name] = v
self.ema.load_state_dict(new_state_dict)
print("=> loaded state_dict_ema")
else:
print("=> failed to find state_dict_ema, starting from loaded model weights)")
def update(self, model):
# correct a mismatch in state dict keys
needs_module = hasattr(model, 'module') and not self.ema_has_module
with torch.no_grad():
msd = model.state_dict()
for k, ema_v in self.ema.state_dict().items():
if needs_module:
k = 'module.' + k
model_v = msd[k].detach()
if self.device:
model_v = model_v.to(device=self.device)
ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)

@ -4,14 +4,17 @@ from __future__ import print_function
import argparse
import os
import csv
import glob
import time
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from models import create_model, apply_test_time_pool
from models import create_model, apply_test_time_pool, load_checkpoint
from data import Dataset, create_loader, resolve_data_config
from utils import accuracy, AverageMeter
from utils import accuracy, AverageMeter, natural_key
torch.backends.cudnn.benchmark = True
@ -46,21 +49,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
help='disable test time pool')
parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
def main():
args = parser.parse_args()
def validate(args):
# create model
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
pretrained=args.pretrained)
if args.checkpoint and not args.pretrained:
load_checkpoint(model, args.checkpoint, args.use_ema)
else:
args.pretrained = True # might as well try to validate something...
print('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
param_count = sum([m.numel() for m in model.parameters()])
print('Model %s created, param count: %d' % (args.model, param_count))
data_config = resolve_data_config(model, args)
model, test_time_pool = apply_test_time_pool(model, data_config, args)
@ -120,8 +128,52 @@ def main():
rate_avg=input.size(0) / batch_time.avg,
loss=losses, top1=top1, top5=top5))
print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
results = OrderedDict(
top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
param_count=round(param_count / 1e6, 2))
print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
results['top1'], results['top1_err'], results['top5'], results['top5_err']))
return results
def main():
args = parser.parse_args()
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
# FIXME just an example list, need to add model name collections for
# batch testing of various pretrained combinations by arg string
models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
model_cfgs = [(n, '') for n in models]
elif os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
else:
model_cfgs = []
if len(model_cfgs):
header_written = False
with open('./results-all.csv', mode='w') as cf:
for m, c in model_cfgs:
args.model = m
args.checkpoint = c
result = OrderedDict(model=args.model)
result.update(validate(args))
if args.checkpoint:
result['checkpoint'] = args.checkpoint
dw = csv.DictWriter(cf, fieldnames=result.keys())
if not header_written:
dw.writeheader()
header_written = True
dw.writerow(result)
cf.flush()
else:
validate(args)
if __name__ == '__main__':

Loading…
Cancel
Save