Merge pull request #12 from rwightman/ema-cleanup

Model weights Exponential Moving Average
Ross Wightman committed via GitHub
commit 0e1fd11ad8

clean_checkpoint.py

@@ -9,6 +9,8 @@ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
 parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
                     help='output path')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')
 
 
 def main():
@@ -24,8 +26,13 @@ def main():
         checkpoint = torch.load(args.checkpoint, map_location='cpu')
         new_state_dict = OrderedDict()
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
-            state_dict = checkpoint['state_dict']
+        if isinstance(checkpoint, dict):
+            state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict'
+            if state_dict_key in checkpoint:
+                state_dict = checkpoint[state_dict_key]
+            else:
+                print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint))
+                exit(1)
         else:
             state_dict = checkpoint
         for k, v in state_dict.items():

models/helpers.py

@@ -4,22 +4,24 @@ import os
 from collections import OrderedDict
 
 
-def load_checkpoint(model, checkpoint_path):
+def load_checkpoint(model, checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        print("=> Loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
-        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
+        state_dict_key = ''
+        if isinstance(checkpoint, dict):
+            state_dict_key = 'state_dict'
+            if use_ema and 'state_dict_ema' in checkpoint:
+                state_dict_key = 'state_dict_ema'
+        if state_dict_key and state_dict_key in checkpoint:
             new_state_dict = OrderedDict()
-            for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+            for k, v in checkpoint[state_dict_key].items():
+                # strip `module.` prefix
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
         else:
             model.load_state_dict(checkpoint)
-        print("=> Loaded checkpoint '{}'".format(checkpoint_path))
+        print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path))
     else:
         print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
@@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path):
 def resume_checkpoint(model, checkpoint_path, start_epoch=None):
     optimizer_state = None
     if os.path.isfile(checkpoint_path):
-        print("=> loading checkpoint '{}'".format(checkpoint_path))
         checkpoint = torch.load(checkpoint_path)
         if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
             new_state_dict = OrderedDict()
             for k, v in checkpoint['state_dict'].items():
-                if k.startswith('module'):
-                    name = k[7:]  # remove `module.`
-                else:
-                    name = k
+                name = k[7:] if k.startswith('module') else k
                 new_state_dict[name] = v
             model.load_state_dict(new_state_dict)
             if 'optimizer' in checkpoint:
                 optimizer_state = checkpoint['optimizer']
-            print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
             start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
         else:
             model.load_state_dict(checkpoint)
             start_epoch = 0 if start_epoch is None else start_epoch
+            print("=> Loaded checkpoint '{}'".format(checkpoint_path))
         return optimizer_state, start_epoch
     else:
-        print("=> No checkpoint found at '{}'".format(checkpoint_path))
+        print("=> Error: No checkpoint found at '{}'".format(checkpoint_path))
         raise FileNotFoundError()
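
For context, a minimal usage sketch of the updated helper (not part of this diff; the model name and checkpoint path below are placeholders): with use_ema=True, load_checkpoint selects the 'state_dict_ema' entry when the checkpoint contains one and otherwise falls back to the regular 'state_dict'.

# Hypothetical example of calling the new load_checkpoint() to evaluate EMA weights.
from models import create_model, load_checkpoint

model = create_model('tf_efficientnet_b0', num_classes=1000, in_chans=3)  # any model name works here
load_checkpoint(model, 'path/to/checkpoint.pth.tar', use_ema=True)  # loads 'state_dict_ema' if present
model.eval()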

optim/rmsprop_tf.py

@@ -89,7 +89,7 @@ class RMSpropTF(Optimizer):
                 state['step'] += 1
 
                 if group['weight_decay'] != 0:
-                    if group['decoupled_decay']:
+                    if 'decoupled_decay' in group and group['decoupled_decay']:
                         p.data.add_(-group['weight_decay'], p.data)
                     else:
                         grad = grad.add(group['weight_decay'], p.data)
@@ -109,7 +109,7 @@ class RMSpropTF(Optimizer):
                 if group['momentum'] > 0:
                     buf = state['momentum_buffer']
                     # Tensorflow accumulates the LR scaling in the momentum buffer
-                    if group['lr_in_momentum']:
+                    if 'lr_in_momentum' in group and group['lr_in_momentum']:
                         buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg)
                         p.data.add_(-buf)
                     else:

train.py

@@ -6,12 +6,13 @@ from datetime import datetime
 try:
     from apex import amp
     from apex.parallel import DistributedDataParallel as DDP
+    from apex.parallel import convert_syncbn_model
     has_apex = True
 except ImportError:
     has_apex = False
 
 from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target
-from models import create_model, resume_checkpoint
+from models import create_model, resume_checkpoint, load_checkpoint
 from utils import *
 from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
 from optim import create_optimizer
@@ -41,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
                     help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
-parser.add_argument('--img-size', type=int, default=224, metavar='N',
-                    help='Image patch size (default: 224)')
+parser.add_argument('--img-size', type=int, default=None, metavar='N',
+                    help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--model-ema', action='store_true', default=False,
+                    help='Enable tracking moving average of model weights')
+parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
+                    help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
+parser.add_argument('--model-ema-decay', type=float, default=0.9998,
+                    help='decay factor for model weights moving average (default: 0.9998)')
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
-parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
+parser.add_argument('--recovery-interval', type=int, default=0, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
 parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
@@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='enabling apex sync BN.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -131,36 +140,24 @@ def main():
     args.device = 'cuda:0'
     args.world_size = 1
-    r = -1
+    args.rank = 0  # global rank
     if args.distributed:
         args.num_gpu = 1
         args.device = 'cuda:%d' % args.local_rank
         torch.cuda.set_device(args.local_rank)
-        torch.distributed.init_process_group(backend='nccl',
-                                             init_method='env://')
+        torch.distributed.init_process_group(
+            backend='nccl', init_method='env://')
         args.world_size = torch.distributed.get_world_size()
-        r = torch.distributed.get_rank()
+        args.rank = torch.distributed.get_rank()
+    assert args.rank >= 0
 
     if args.distributed:
         print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
-              % (r, args.world_size))
+              % (args.rank, args.world_size))
     else:
         print('Training with a single process on %d GPUs.' % args.num_gpu)
 
-    # FIXME seed handling for multi-process distributed?
-    torch.manual_seed(args.seed)
-
-    output_dir = ''
-    if args.local_rank == 0:
-        if args.output:
-            output_base = args.output
-        else:
-            output_base = './output'
-        exp_name = '-'.join([
-            datetime.now().strftime("%Y%m%d-%H%M%S"),
-            args.model,
-            str(args.img_size)])
-        output_dir = get_outdir(output_base, 'train', exp_name)
+    torch.manual_seed(args.seed + args.rank)
 
     model = create_model(
         args.model,
@@ -191,6 +188,8 @@ def main():
             args.amp = False
         model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
     else:
+        if args.distributed and args.sync_bn and has_apex:
+            model = convert_syncbn_model(model)
         model.cuda()
 
     optimizer = create_optimizer(args, model)
@@ -205,8 +204,20 @@ def main():
         use_amp = False
         print('AMP disabled')
 
+    model_ema = None
+    if args.model_ema:
+        model_ema = ModelEma(
+            model,
+            decay=args.model_ema_decay,
+            device='cpu' if args.model_ema_force_cpu else '',
+            resume=args.resume)
+
     if args.distributed:
         model = DDP(model, delay_allreduce=True)
+        if model_ema is not None and not args.model_ema_force_cpu:
+            # must also distribute EMA model to allow validation
+            model_ema.ema = DDP(model_ema.ema, delay_allreduce=True)
+            model_ema.ema_has_module = True
 
     lr_scheduler, num_epochs = create_scheduler(args, optimizer)
     if start_epoch > 0:
@@ -271,12 +282,21 @@ def main():
         validate_loss_fn = train_loss_fn
 
     eval_metric = args.eval_metric
+    best_metric = None
+    best_epoch = None
     saver = None
-    if output_dir:
+    output_dir = ''
+    if args.local_rank == 0:
+        output_base = args.output if args.output else './output'
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            args.model,
+            str(data_config['input_size'][-1])
+        ])
+        output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
-    best_metric = None
-    best_epoch = None
 
     try:
         for epoch in range(start_epoch, num_epochs):
             if args.distributed:
@@ -284,10 +304,15 @@ def main():
             train_metrics = train_epoch(
                 epoch, model, loader_train, optimizer, train_loss_fn, args,
-                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp)
+                lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
+                use_amp=use_amp, model_ema=model_ema)
 
-            eval_metrics = validate(
-                model, loader_eval, validate_loss_fn, args)
+            eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
+
+            if model_ema is not None and not args.model_ema_force_cpu:
+                ema_eval_metrics = validate(
+                    model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
+                eval_metrics = ema_eval_metrics
 
             if lr_scheduler is not None:
                 lr_scheduler.step(epoch, eval_metrics[eval_metric])
@@ -298,15 +323,12 @@ def main():
             if saver is not None:
                 # save proper checkpoint with eval metric
-                best_metric, best_epoch = saver.save_checkpoint({
-                    'epoch': epoch + 1,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
+                save_metric = eval_metrics[eval_metric]
+                best_metric, best_epoch = saver.save_checkpoint(
+                    model, optimizer, args,
                     epoch=epoch + 1,
-                    metric=eval_metrics[eval_metric])
+                    model_ema=model_ema,
+                    metric=save_metric)
 
     except KeyboardInterrupt:
         pass
@@ -316,7 +338,7 @@ def main():
 
 def train_epoch(
         epoch, model, loader, optimizer, loss_fn, args,
-        lr_scheduler=None, saver=None, output_dir='', use_amp=False):
+        lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None):
 
     if args.prefetcher and args.mixup > 0 and loader.mixup_enabled:
         if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
@@ -359,6 +381,8 @@ def train_epoch(
         optimizer.step()
 
         torch.cuda.synchronize()
+        if model_ema is not None:
+            model_ema.update(model)
         num_updates += 1
 
         batch_time_m.update(time.time() - end)
@@ -394,18 +418,11 @@ def train_epoch(
                         padding=0,
                         normalize=True)
 
-            if args.local_rank == 0 and (
-                    saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0):
+            if saver is not None and args.recovery_interval and (
+                    last_batch or (batch_idx + 1) % args.recovery_interval == 0):
                 save_epoch = epoch + 1 if last_batch else epoch
-                saver.save_recovery({
-                    'epoch': save_epoch,
-                    'arch': args.model,
-                    'state_dict': model.state_dict(),
-                    'optimizer': optimizer.state_dict(),
-                    'args': args,
-                    },
-                    epoch=save_epoch,
-                    batch_idx=batch_idx)
+                saver.save_recovery(
+                    model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx)
 
             if lr_scheduler is not None:
                 lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg)
@@ -415,7 +432,7 @@ def train_epoch(
     return OrderedDict([('loss', losses_m.avg)])
 
 
-def validate(model, loader, loss_fn, args):
+def validate(model, loader, loss_fn, args, log_suffix=''):
     batch_time_m = AverageMeter()
     losses_m = AverageMeter()
     prec1_m = AverageMeter()
@@ -461,12 +478,13 @@ def validate(model, loader, loss_fn, args):
             batch_time_m.update(time.time() - end)
             end = time.time()
             if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0):
-                print('Test: [{0}/{1}]\t'
+                log_name = 'Test' + log_suffix
+                print('{0}: [{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                      'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                     'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                     'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                        batch_idx, last_idx,
+                        log_name, batch_idx, last_idx,
                         batch_time=batch_time_m, loss=losses_m,
                         top1=prec1_m, top5=prec5_m))
 
@@ -475,12 +493,5 @@ def validate(model, loader, loss_fn, args):
     return metrics
 
 
-def reduce_tensor(tensor, n):
-    rt = tensor.clone()
-    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
-    rt /= n
-    return rt
-
-
 if __name__ == '__main__':
     main()

utils.py

@@ -1,6 +1,9 @@
+from copy import deepcopy
 import torch
 import math
 import os
+import re
 import shutil
 import glob
 import csv
@@ -8,6 +11,15 @@ import operator
 import numpy as np
 from collections import OrderedDict
+from torch import distributed as dist
+
+
+def get_state_dict(model):
+    if isinstance(model, ModelEma):
+        return get_state_dict(model.ema)
+    else:
+        return model.module.state_dict() if getattr(model, 'module') else model.state_dict()
 
 
 class CheckpointSaver:
     def __init__(
@@ -39,17 +51,16 @@ class CheckpointSaver:
         self.max_history = max_history
         assert self.max_history >= 1
 
-    def save_checkpoint(self, state, epoch, metric=None):
+    def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None):
+        assert epoch >= 0
         worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
-        if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]):
+        if (len(self.checkpoint_files) < self.max_history
+                or metric is None or self.cmp(metric, worst_file[1])):
             if len(self.checkpoint_files) >= self.max_history:
                 self._cleanup_checkpoints(1)
             filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension
             save_path = os.path.join(self.checkpoint_dir, filename)
-            if metric is not None:
-                state['metric'] = metric
-            torch.save(state, save_path)
+            self._save(save_path, model, optimizer, args, epoch, model_ema, metric)
             self.checkpoint_files.append((save_path, metric))
             self.checkpoint_files = sorted(
                 self.checkpoint_files, key=lambda x: x[1],
@@ -67,6 +78,20 @@ class CheckpointSaver:
         return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
 
+    def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None):
+        save_state = {
+            'epoch': epoch,
+            'arch': args.model,
+            'state_dict': get_state_dict(model),
+            'optimizer': optimizer.state_dict(),
+            'args': args
+        }
+        if model_ema is not None:
+            save_state['state_dict_ema'] = get_state_dict(model_ema)
+        if metric is not None:
+            save_state['metric'] = metric
+        torch.save(save_state, save_path)
+
     def _cleanup_checkpoints(self, trim=0):
         trim = min(len(self.checkpoint_files), trim)
         delete_index = self.max_history - trim
@@ -82,10 +107,11 @@ class CheckpointSaver:
             print('Exception (%s) while deleting checkpoint' % str(e))
         self.checkpoint_files = self.checkpoint_files[:delete_index]
 
-    def save_recovery(self, state, epoch, batch_idx):
+    def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0):
+        assert epoch >= 0
         filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension
         save_path = os.path.join(self.recovery_dir, filename)
-        torch.save(state, save_path)
+        self._save(save_path, model, optimizer, args, epoch, model_ema)
         if os.path.exists(self.last_recovery_file):
             try:
                 if self.verbose:
@@ -165,3 +191,81 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa
         if write_header:  # first iteration (epoch == 1 can't be used)
             dw.writeheader()
         dw.writerow(rowd)
+
+
+def natural_key(string_):
+    """See http://www.codinghorror.com/blog/archives/001018.html"""
+    return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
+def reduce_tensor(tensor, n):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= n
+    return rt
+
+
+class ModelEma:
+    """ Model Exponential Moving Average
+    Keep a moving average of everything in the model state_dict (parameters and buffers).
+    This is intended to allow functionality like
+    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
+
+    A smoothed version of the weights is necessary for some training schemes to perform well.
+    E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use
+    RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA
+    smoothing of weights to match results. Pay attention to the decay constant you are using
+    relative to your update count per epoch.
+
+    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
+    disable validation of the EMA weights. Validation will have to be done manually in a separate
+    process, or after the training stops converging.
+
+    This class is sensitive where it is initialized in the sequence of model init,
+    GPU assignment and distributed training wrappers.
+    I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU.
+    """
+    def __init__(self, model, decay=0.9999, device='', resume=''):
+        # make a copy of the model for accumulating moving average of weights
+        self.ema = deepcopy(model)
+        self.ema.eval()
+        self.decay = decay
+        self.device = device  # perform ema on different device from model if set
+        if device:
+            self.ema.to(device=device)
+        self.ema_has_module = hasattr(self.ema, 'module')
+        if resume:
+            self._load_checkpoint(resume)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def _load_checkpoint(self, checkpoint_path):
+        checkpoint = torch.load(checkpoint_path)
+        assert isinstance(checkpoint, dict)
+        if 'state_dict_ema' in checkpoint:
+            new_state_dict = OrderedDict()
+            for k, v in checkpoint['state_dict_ema'].items():
+                # ema model may have been wrapped by DataParallel, and need module prefix
+                if self.ema_has_module:
+                    name = 'module.' + k if not k.startswith('module') else k
+                else:
+                    name = k
+                new_state_dict[name] = v
+            self.ema.load_state_dict(new_state_dict)
+            print("=> Loaded state_dict_ema")
+        else:
+            print("=> Failed to find state_dict_ema, starting from loaded model weights")
+
+    def update(self, model):
+        # correct a mismatch in state dict keys
+        needs_module = hasattr(model, 'module') and not self.ema_has_module
+        with torch.no_grad():
+            msd = model.state_dict()
+            for k, ema_v in self.ema.state_dict().items():
+                if needs_module:
+                    k = 'module.' + k
+                model_v = msd[k].detach()
+                if self.device:
+                    model_v = model_v.to(device=self.device)
+                ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
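
To make the intended wiring concrete, here is a minimal, self-contained sketch of the ModelEma lifecycle that mirrors the train.py changes above: construct it right after the model (and any resume) is set up, call update() after every optimizer step, and read the smoothed weights from model_ema.ema for validation or checkpointing. The toy nn.Linear model, data, and loss are stand-ins, not part of this PR; the rule applied to every tensor in the state dict is ema = decay * ema + (1 - decay) * weight.

# Minimal CPU-only sketch, assuming the ModelEma class defined above is available.
import torch
import torch.nn as nn

model = nn.Linear(10, 2)                          # stand-in for create_model(...)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model_ema = ModelEma(model, decay=0.9998)         # detached copy with requires_grad disabled

for step in range(100):
    x = torch.randn(8, 10)
    loss = model(x).pow(2).mean()                 # dummy loss purely for illustration
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model_ema.update(model)                       # ema = decay * ema + (1 - decay) * weight

model_ema.ema.eval()                              # validate the smoothed copy, as validate() does with model_ema.ema
ema_weights = model_ema.ema.state_dict()          # what CheckpointSaver stores under 'state_dict_ema'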

validate.py

@@ -4,14 +4,17 @@ from __future__ import print_function
 
 import argparse
 import os
+import csv
+import glob
 import time
 import torch
 import torch.nn as nn
 import torch.nn.parallel
+from collections import OrderedDict
 
-from models import create_model, apply_test_time_pool
+from models import create_model, apply_test_time_pool, load_checkpoint
 from data import Dataset, create_loader, resolve_data_config
-from utils import accuracy, AverageMeter
+from utils import accuracy, AverageMeter, natural_key
 
 torch.backends.cudnn.benchmark = True
 
@@ -46,21 +49,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                     help='disable test time pool')
 parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true',
                     help='Use Tensorflow preprocessing pipeline (require CPU TF installed')
+parser.add_argument('--use-ema', dest='use_ema', action='store_true',
+                    help='use ema version of weights if present')
 
 
-def main():
-    args = parser.parse_args()
-
+def validate(args):
     # create model
     model = create_model(
         args.model,
         num_classes=args.num_classes,
         in_chans=3,
-        pretrained=args.pretrained,
-        checkpoint_path=args.checkpoint)
+        pretrained=args.pretrained)
+
+    if args.checkpoint and not args.pretrained:
+        load_checkpoint(model, args.checkpoint, args.use_ema)
+    else:
+        args.pretrained = True  # might as well try to validate something...
 
-    print('Model %s created, param count: %d' %
-          (args.model, sum([m.numel() for m in model.parameters()])))
+    param_count = sum([m.numel() for m in model.parameters()])
+    print('Model %s created, param count: %d' % (args.model, param_count))
 
     data_config = resolve_data_config(model, args)
     model, test_time_pool = apply_test_time_pool(model, data_config, args)
@@ -120,8 +128,52 @@ def main():
             rate_avg=input.size(0) / batch_time.avg,
             loss=losses, top1=top1, top5=top5))
 
-    print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format(
-        top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg))
+    results = OrderedDict(
+        top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3),
+        top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3),
+        param_count=round(param_count / 1e6, 2))
+
+    print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
+        results['top1'], results['top1_err'], results['top5'], results['top5_err']))
+
+    return results
+
+
+def main():
+    args = parser.parse_args()
+
+    if args.model == 'all':
+        # validate all models in a list of names with pretrained checkpoints
+        args.pretrained = True
+        # FIXME just an example list, need to add model name collections for
+        # batch testing of various pretrained combinations by arg string
+        models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3']
+        model_cfgs = [(n, '') for n in models]
+    elif os.path.isdir(args.checkpoint):
+        # validate all checkpoints in a path with same model
+        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
+        checkpoints += glob.glob(args.checkpoint + '/*.pth')
+        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
+    else:
+        model_cfgs = []
+
+    if len(model_cfgs):
+        header_written = False
+        with open('./results-all.csv', mode='w') as cf:
+            for m, c in model_cfgs:
+                args.model = m
+                args.checkpoint = c
+                result = OrderedDict(model=args.model)
+                result.update(validate(args))
+                if args.checkpoint:
+                    result['checkpoint'] = args.checkpoint
+                dw = csv.DictWriter(cf, fieldnames=result.keys())
+                if not header_written:
+                    dw.writeheader()
+                    header_written = True
+                dw.writerow(result)
+                cf.flush()
+    else:
+        validate(args)
 
 
 if __name__ == '__main__':
