From 9bcd65181badec95213b47edbf3a96bc12f8e39b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 7 Jun 2019 15:39:36 -0700 Subject: [PATCH 1/2] Add exponential moving average for model weights + few other additions and cleanup * ModelEma class added to track an EMA set of weights for the model being trained * EMA handling added to train, validation and clean_checkpoint scripts * Add multi checkpoint or multi-model validation support to validate.py * Add syncbn option (APEX) to train script for experimentation * Cleanup interface of CheckpointSaver while adding ema functionality --- clean_checkpoint.py | 11 +++- models/helpers.py | 31 ++++++------ optim/rmsprop_tf.py | 4 +- train.py | 104 +++++++++++++++++++++----------------- utils.py | 120 +++++++++++++++++++++++++++++++++++++++++--- validate.py | 72 ++++++++++++++++++++++---- 6 files changed, 258 insertions(+), 84 deletions(-) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index 471630b6..59a6e306 100644 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -9,6 +9,8 @@ parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', help='path to latest checkpoint (default: none)') parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH', help='output path') +parser.add_argument('--use-ema', dest='use_ema', action='store_true', + help='use ema version of weights if present') def main(): @@ -24,8 +26,13 @@ def main(): checkpoint = torch.load(args.checkpoint, map_location='cpu') new_state_dict = OrderedDict() - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: - state_dict = checkpoint['state_dict'] + if isinstance(checkpoint, dict): + state_dict_key = 'state_dict_ema' if args.use_ema else 'state_dict' + if state_dict_key in checkpoint: + state_dict = checkpoint[state_dict_key] + else: + print("Error: No state_dict found in checkpoint {}.".format(args.checkpoint)) + exit(1) else: state_dict = checkpoint for k, v in state_dict.items(): diff --git a/models/helpers.py b/models/helpers.py index deb4d1b1..fbf92e37 100644 --- a/models/helpers.py +++ b/models/helpers.py @@ -4,22 +4,24 @@ import os from collections import OrderedDict -def load_checkpoint(model, checkpoint_path): +def load_checkpoint(model, checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): - print("=> Loading checkpoint '{}'".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) - if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + state_dict_key = '' + if isinstance(checkpoint, dict): + state_dict_key = 'state_dict' + if use_ema and 'state_dict_ema' in checkpoint: + state_dict_key = 'state_dict_ema' + if state_dict_key and state_dict_key in checkpoint: new_state_dict = OrderedDict() - for k, v in checkpoint['state_dict'].items(): - if k.startswith('module'): - name = k[7:] # remove `module.` - else: - name = k + for k, v in checkpoint[state_dict_key].items(): + # strip `module.` prefix + name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) else: model.load_state_dict(checkpoint) - print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + print("=> Loaded {} from checkpoint '{}'".format(state_dict_key or 'weights', checkpoint_path)) else: print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() @@ -28,27 +30,24 @@ def load_checkpoint(model, checkpoint_path): def resume_checkpoint(model, checkpoint_path, start_epoch=None): optimizer_state = None if 
os.path.isfile(checkpoint_path): - print("=> loading checkpoint '{}'".format(checkpoint_path)) checkpoint = torch.load(checkpoint_path) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items(): - if k.startswith('module'): - name = k[7:] # remove `module.` - else: - name = k + name = k[7:] if k.startswith('module') else k new_state_dict[name] = v model.load_state_dict(new_state_dict) if 'optimizer' in checkpoint: optimizer_state = checkpoint['optimizer'] - print("=> loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch + print("=> Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) start_epoch = 0 if start_epoch is None else start_epoch + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) return optimizer_state, start_epoch else: - print("=> No checkpoint found at '{}'".format(checkpoint_path)) + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/optim/rmsprop_tf.py b/optim/rmsprop_tf.py index b4298734..0f44b944 100644 --- a/optim/rmsprop_tf.py +++ b/optim/rmsprop_tf.py @@ -89,7 +89,7 @@ class RMSpropTF(Optimizer): state['step'] += 1 if group['weight_decay'] != 0: - if group['decoupled_decay']: + if 'decoupled_decay' in group and group['decoupled_decay']: p.data.add_(-group['weight_decay'], p.data) else: grad = grad.add(group['weight_decay'], p.data) @@ -109,7 +109,7 @@ class RMSpropTF(Optimizer): if group['momentum'] > 0: buf = state['momentum_buffer'] # Tensorflow accumulates the LR scaling in the momentum buffer - if group['lr_in_momentum']: + if 'lr_in_momentum' in group and group['lr_in_momentum']: buf.mul_(group['momentum']).addcdiv_(group['lr'], grad, avg) p.data.add_(-buf) else: diff --git a/train.py b/train.py index 9a81eecb..afd54a6f 100644 --- a/train.py +++ b/train.py @@ -6,12 +6,13 @@ from datetime import datetime try: from apex import amp from apex.parallel import DistributedDataParallel as DDP + from apex.parallel import convert_syncbn_model has_apex = True except ImportError: has_apex = False from data import Dataset, create_loader, resolve_data_config, FastCollateMixup, mixup_target -from models import create_model, resume_checkpoint +from models import create_model, resume_checkpoint, load_checkpoint from utils import * from loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy from optim import create_optimizer @@ -91,11 +92,17 @@ parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') +parser.add_argument('--model-ema', action='store_true', default=False, + help='Enable tracking moving average of model weights') +parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, + help='Force ema to be tracked on CPU, rank=0 node only. 
Disables EMA validation.') +parser.add_argument('--model-ema-decay', type=float, default=0.9998, + help='decay factor for model weights moving average (default: 0.9998)') parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') -parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N', +parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') @@ -109,6 +116,8 @@ parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA amp for mixed precision training') +parser.add_argument('--sync-bn', action='store_true', + help='enabling apex sync BN.') parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', @@ -131,31 +140,28 @@ def main(): args.device = 'cuda:0' args.world_size = 1 - r = -1 + args.rank = 0 # global rank if args.distributed: args.num_gpu = 1 args.device = 'cuda:%d' % args.local_rank torch.cuda.set_device(args.local_rank) - torch.distributed.init_process_group(backend='nccl', - init_method='env://') + torch.distributed.init_process_group( + backend='nccl', init_method='env://') args.world_size = torch.distributed.get_world_size() - r = torch.distributed.get_rank() + args.rank = torch.distributed.get_rank() + assert args.rank >= 0 if args.distributed: print('Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.' - % (r, args.world_size)) + % (args.rank, args.world_size)) else: print('Training with a single process on %d GPUs.' % args.num_gpu) - # FIXME seed handling for multi-process distributed? 
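# ---------------------------------------------------------------------------
# Reviewer note (illustrative sketch, not part of the patch): how the
# --model-ema-decay value added above relates to the number of optimizer
# updates the EMA effectively averages over. The identities below are standard
# EMA math; the updates_per_epoch figure is a hypothetical example, not a
# value from this patch.
# ---------------------------------------------------------------------------
import math

def ema_horizon(decay):
    # approximate number of recent updates the EMA averages over
    return 1.0 / (1.0 - decay)

def ema_half_life(decay):
    # updates needed for an old weight's contribution to fall to 50%
    return math.log(0.5) / math.log(decay)

decay = 0.9998                  # default of --model-ema-decay added above
print(ema_horizon(decay))       # ~5000 updates
print(ema_half_life(decay))     # ~3466 updates
updates_per_epoch = 1250        # hypothetical: ~1.28M images at batch size 1024
print(ema_half_life(decay) / updates_per_epoch)   # half-life of ~2.8 epochs
# This is why the ModelEma docstring below warns to pick the decay constant
# relative to the update count per epoch.
# ---------------------------------------------------------------------------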
- torch.manual_seed(args.seed) + torch.manual_seed(args.seed + args.rank) output_dir = '' if args.local_rank == 0: - if args.output: - output_base = args.output - else: - output_base = './output' + output_base = args.output if args.output else './output' exp_name = '-'.join([ datetime.now().strftime("%Y%m%d-%H%M%S"), args.model, @@ -191,6 +197,8 @@ def main(): args.amp = False model = nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() else: + if args.distributed and args.sync_bn and has_apex: + model = convert_syncbn_model(model) model.cuda() optimizer = create_optimizer(args, model) @@ -205,8 +213,20 @@ def main(): use_amp = False print('AMP disabled') + model_ema = None + if args.model_ema: + model_ema = ModelEma( + model, + decay=args.model_ema_decay, + device='cpu' if args.model_ema_force_cpu else '', + resume=args.resume) + if args.distributed: model = DDP(model, delay_allreduce=True) + if model_ema is not None and not args.model_ema_force_cpu: + # must also distribute EMA model to allow validation + model_ema.ema = DDP(model_ema.ema, delay_allreduce=True) + model_ema.ema_has_module = True lr_scheduler, num_epochs = create_scheduler(args, optimizer) if start_epoch > 0: @@ -273,6 +293,7 @@ def main(): eval_metric = args.eval_metric saver = None if output_dir: + # only set if process is rank 0 decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) best_metric = None @@ -284,10 +305,15 @@ def main(): train_metrics = train_epoch( epoch, model, loader_train, optimizer, train_loss_fn, args, - lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, use_amp=use_amp) + lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, + use_amp=use_amp, model_ema=model_ema) + + eval_metrics = validate(model, loader_eval, validate_loss_fn, args) - eval_metrics = validate( - model, loader_eval, validate_loss_fn, args) + if model_ema is not None and not args.model_ema_force_cpu: + ema_eval_metrics = validate( + model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') + eval_metrics = ema_eval_metrics if lr_scheduler is not None: lr_scheduler.step(epoch, eval_metrics[eval_metric]) @@ -298,15 +324,12 @@ def main(): if saver is not None: # save proper checkpoint with eval metric - best_metric, best_epoch = saver.save_checkpoint({ - 'epoch': epoch + 1, - 'arch': args.model, - 'state_dict': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'args': args, - }, + save_metric = eval_metrics[eval_metric] + best_metric, best_epoch = saver.save_checkpoint( + model, optimizer, args, epoch=epoch + 1, - metric=eval_metrics[eval_metric]) + model_ema=model_ema, + metric=save_metric) except KeyboardInterrupt: pass @@ -316,7 +339,7 @@ def main(): def train_epoch( epoch, model, loader, optimizer, loss_fn, args, - lr_scheduler=None, saver=None, output_dir='', use_amp=False): + lr_scheduler=None, saver=None, output_dir='', use_amp=False, model_ema=None): if args.prefetcher and args.mixup > 0 and loader.mixup_enabled: if args.mixup_off_epoch and epoch >= args.mixup_off_epoch: @@ -359,6 +382,8 @@ def train_epoch( optimizer.step() torch.cuda.synchronize() + if model_ema is not None: + model_ema.update(model) num_updates += 1 batch_time_m.update(time.time() - end) @@ -394,18 +419,11 @@ def train_epoch( padding=0, normalize=True) - if args.local_rank == 0 and ( - saver is not None and last_batch or (batch_idx + 1) % args.recovery_interval == 0): + if saver is not None and args.recovery_interval and ( + 
last_batch or (batch_idx + 1) % args.recovery_interval == 0): save_epoch = epoch + 1 if last_batch else epoch - saver.save_recovery({ - 'epoch': save_epoch, - 'arch': args.model, - 'state_dict': model.state_dict(), - 'optimizer': optimizer.state_dict(), - 'args': args, - }, - epoch=save_epoch, - batch_idx=batch_idx) + saver.save_recovery( + model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) @@ -415,7 +433,7 @@ def train_epoch( return OrderedDict([('loss', losses_m.avg)]) -def validate(model, loader, loss_fn, args): +def validate(model, loader, loss_fn, args, log_suffix=''): batch_time_m = AverageMeter() losses_m = AverageMeter() prec1_m = AverageMeter() @@ -461,12 +479,13 @@ def validate(model, loader, loss_fn, args): batch_time_m.update(time.time() - end) end = time.time() if args.local_rank == 0 and (last_batch or batch_idx % args.log_interval == 0): - print('Test: [{0}/{1}]\t' + log_name = 'Test' + log_suffix + print('{0}: [{1}/{2}]\t' 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) ' 'Loss {loss.val:.4f} ({loss.avg:.4f}) ' 'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) ' 'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format( - batch_idx, last_idx, + log_name, batch_idx, last_idx, batch_time=batch_time_m, loss=losses_m, top1=prec1_m, top5=prec5_m)) @@ -475,12 +494,5 @@ def validate(model, loader, loss_fn, args): return metrics -def reduce_tensor(tensor, n): - rt = tensor.clone() - dist.all_reduce(rt, op=dist.ReduceOp.SUM) - rt /= n - return rt - - if __name__ == '__main__': main() diff --git a/utils.py b/utils.py index 48936aad..43f58117 100644 --- a/utils.py +++ b/utils.py @@ -1,6 +1,9 @@ +from copy import deepcopy + import torch import math import os +import re import shutil import glob import csv @@ -8,6 +11,15 @@ import operator import numpy as np from collections import OrderedDict +from torch import distributed as dist + + +def get_state_dict(model): + if isinstance(model, ModelEma): + return get_state_dict(model.ema) + else: + return model.module.state_dict() if getattr(model, 'module') else model.state_dict() + class CheckpointSaver: def __init__( @@ -39,17 +51,16 @@ class CheckpointSaver: self.max_history = max_history assert self.max_history >= 1 - def save_checkpoint(self, state, epoch, metric=None): + def save_checkpoint(self, model, optimizer, args, epoch, model_ema=None, metric=None): + assert epoch >= 0 worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None - if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]): + if (len(self.checkpoint_files) < self.max_history + or metric is None or self.cmp(metric, worst_file[1])): if len(self.checkpoint_files) >= self.max_history: self._cleanup_checkpoints(1) - filename = '-'.join([self.save_prefix, str(epoch)]) + self.extension save_path = os.path.join(self.checkpoint_dir, filename) - if metric is not None: - state['metric'] = metric - torch.save(state, save_path) + self._save(save_path, model, optimizer, args, epoch, model_ema, metric) self.checkpoint_files.append((save_path, metric)) self.checkpoint_files = sorted( self.checkpoint_files, key=lambda x: x[1], @@ -67,6 +78,20 @@ class CheckpointSaver: return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) + def _save(self, save_path, model, optimizer, args, epoch, model_ema=None, metric=None): + save_state = { + 'epoch': epoch, + 'arch': args.model, + 'state_dict': get_state_dict(model), 
+ 'optimizer': optimizer.state_dict(), + 'args': args + } + if model_ema is not None: + save_state['state_dict_ema'] = get_state_dict(model_ema) + if metric is not None: + save_state['metric'] = metric + torch.save(save_state, save_path) + def _cleanup_checkpoints(self, trim=0): trim = min(len(self.checkpoint_files), trim) delete_index = self.max_history - trim @@ -82,10 +107,11 @@ class CheckpointSaver: print('Exception (%s) while deleting checkpoint' % str(e)) self.checkpoint_files = self.checkpoint_files[:delete_index] - def save_recovery(self, state, epoch, batch_idx): + def save_recovery(self, model, optimizer, args, epoch, model_ema=None, batch_idx=0): + assert epoch >= 0 filename = '-'.join([self.recovery_prefix, str(epoch), str(batch_idx)]) + self.extension save_path = os.path.join(self.recovery_dir, filename) - torch.save(state, save_path) + self._save(save_path, model, optimizer, args, epoch, model_ema) if os.path.exists(self.last_recovery_file): try: if self.verbose: @@ -165,3 +191,81 @@ def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=Fa if write_header: # first iteration (epoch == 1 can't be used) dw.writeheader() dw.writerow(rowd) + + +def natural_key(string_): + """See http://www.codinghorror.com/blog/archives/001018.html""" + return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] + + +def reduce_tensor(tensor, n): + rt = tensor.clone() + dist.all_reduce(rt, op=dist.ReduceOp.SUM) + rt /= n + return rt + + +class ModelEma: + """ Model Exponential Moving Average + Keep a moving average of everything in the model state_dict (parameters and buffers). + + This is intended to allow functionality like + https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + + A smoothed version of the weights is necessary for some training schemes to perform well. + E.g. Google's hyper-params for training MNASNet, MobileNet-V3, EfficientNet, etc that use + RMSprop with a short 2.4-3 epoch decay period and slow LR decay rate of .96-.99 requires EMA + smoothing of weights to match results. Pay attention to the decay constant you are using + relative to your update count per epoch. + + To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but + disable validation of the EMA weights. Validation will have to be done manually in a separate + process, or after the training stops converging. + + This class is sensitive where it is initialized in the sequence of model init, + GPU assignment and distributed training wrappers. + I've tested with the sequence in my own train.py for torch.DataParallel, apex.DDP, and single-GPU. + """ + def __init__(self, model, decay=0.9999, device='', resume=''): + # make a copy of the model for accumulating moving average of weights + self.ema = deepcopy(model) + self.ema.eval() + self.decay = decay + self.device = device # perform ema on different device from model if set + if device: + self.ema.to(device=device) + self.ema_has_module = hasattr(self.ema, 'module') + if resume: + self._load_checkpoint(resume) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def _load_checkpoint(self, checkpoint_path): + checkpoint = torch.load(checkpoint_path) + assert isinstance(checkpoint, dict) + if 'state_dict_ema' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict_ema'].items(): + # ema model may have been wrapped by DataParallel, and need module prefix + if self.ema_has_module: + name = 'module.' 
+ k if not k.startswith('module') else k + else: + name = k + new_state_dict[name] = v + self.ema.load_state_dict(new_state_dict) + print("=> loaded state_dict_ema") + else: + print("=> failed to find state_dict_ema, starting from loaded model weights)") + + def update(self, model): + # correct a mismatch in state dict keys + needs_module = hasattr(model, 'module') and not self.ema_has_module + with torch.no_grad(): + msd = model.state_dict() + for k, ema_v in self.ema.state_dict().items(): + if needs_module: + k = 'module.' + k + model_v = msd[k].detach() + if self.device: + model_v = model_v.to(device=self.device) + ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v) diff --git a/validate.py b/validate.py index 4373756a..e0d0a538 100644 --- a/validate.py +++ b/validate.py @@ -4,14 +4,17 @@ from __future__ import print_function import argparse import os +import csv +import glob import time import torch import torch.nn as nn import torch.nn.parallel +from collections import OrderedDict -from models import create_model, apply_test_time_pool +from models import create_model, apply_test_time_pool, load_checkpoint from data import Dataset, create_loader, resolve_data_config -from utils import accuracy, AverageMeter +from utils import accuracy, AverageMeter, natural_key torch.backends.cudnn.benchmark = True @@ -46,21 +49,26 @@ parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', help='disable test time pool') parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', help='Use Tensorflow preprocessing pipeline (require CPU TF installed') +parser.add_argument('--use-ema', dest='use_ema', action='store_true', + help='use ema version of weights if present') -def main(): - args = parser.parse_args() +def validate(args): # create model model = create_model( args.model, num_classes=args.num_classes, in_chans=3, - pretrained=args.pretrained, - checkpoint_path=args.checkpoint) + pretrained=args.pretrained) + + if args.checkpoint and not args.pretrained: + load_checkpoint(model, args.checkpoint, args.use_ema) + else: + args.pretrained = True # might as well try to validate something... 
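# ---------------------------------------------------------------------------
# Reviewer note (illustrative sketch, not part of the patch): condensed view
# of the state_dict selection that load_checkpoint() / clean_checkpoint.py
# perform above. A checkpoint written by CheckpointSaver holds 'state_dict'
# and, when --model-ema is enabled, 'state_dict_ema'; --use-ema picks the
# latter, and any 'module.' prefix left by (Distributed)DataParallel is
# stripped so the weights load into an unwrapped model. extract_weights and
# the example path are hypothetical, not names from this patch.
# ---------------------------------------------------------------------------
from collections import OrderedDict
import torch

def extract_weights(checkpoint_path, use_ema=False):
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    if isinstance(checkpoint, dict):
        key = 'state_dict_ema' if use_ema and 'state_dict_ema' in checkpoint else 'state_dict'
        state_dict = checkpoint[key] if key in checkpoint else checkpoint
    else:
        state_dict = checkpoint  # a bare state_dict was saved directly
    return OrderedDict(
        (k[7:] if k.startswith('module.') else k, v) for k, v in state_dict.items())

# usage sketch (placeholder path):
#   model.load_state_dict(extract_weights('checkpoint.pth.tar', use_ema=True))
# ---------------------------------------------------------------------------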
- print('Model %s created, param count: %d' % - (args.model, sum([m.numel() for m in model.parameters()]))) + param_count = sum([m.numel() for m in model.parameters()]) + print('Model %s created, param count: %d' % (args.model, param_count)) data_config = resolve_data_config(model, args) model, test_time_pool = apply_test_time_pool(model, data_config, args) @@ -120,8 +128,52 @@ def main(): rate_avg=input.size(0) / batch_time.avg, loss=losses, top1=top1, top5=top5)) - print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( - top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + results = OrderedDict( + top1=round(top1.avg, 3), top1_err=round(100 - top1.avg, 3), + top5=round(top5.avg, 3), top5_err=round(100 - top5.avg, 3), + param_count=round(param_count / 1e6, 2)) + + print(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format( + results['top1'], results['top1_err'], results['top5'], results['top5_err'])) + + return results + + +def main(): + args = parser.parse_args() + if args.model == 'all': + # validate all models in a list of names with pretrained checkpoints + args.pretrained = True + # FIXME just an example list, need to add model name collections for + # batch testing of various pretrained combinations by arg string + models = ['tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3'] + model_cfgs = [(n, '') for n in models] + elif os.path.isdir(args.checkpoint): + # validate all checkpoints in a path with same model + checkpoints = glob.glob(args.checkpoint + '/*.pth.tar') + checkpoints += glob.glob(args.checkpoint + '/*.pth') + model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)] + else: + model_cfgs = [] + + if len(model_cfgs): + header_written = False + with open('./results-all.csv', mode='w') as cf: + for m, c in model_cfgs: + args.model = m + args.checkpoint = c + result = OrderedDict(model=args.model) + result.update(validate(args)) + if args.checkpoint: + result['checkpoint'] = args.checkpoint + dw = csv.DictWriter(cf, fieldnames=result.keys()) + if not header_written: + dw.writeheader() + header_written = True + dw.writerow(result) + cf.flush() + else: + validate(args) if __name__ == '__main__': From 7dab6d1ec7466d5eb23ddad4e275ff277d7b8cb4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 10 Jun 2019 13:34:42 -0700 Subject: [PATCH 2/2] Default to img_size in model default_cfg, defer output folder creation until later in the init sequence --- train.py | 29 ++++++++++++++--------------- utils.py | 4 ++-- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/train.py b/train.py index afd54a6f..3811f7b4 100644 --- a/train.py +++ b/train.py @@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') parser.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') -parser.add_argument('--img-size', type=int, default=224, metavar='N', - help='Image patch size (default: 224)') +parser.add_argument('--img-size', type=int, default=None, metavar='N', + help='Image patch size (default: None => model default)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', @@ -159,15 +159,6 @@ def main(): torch.manual_seed(args.seed + args.rank) - output_dir = '' - if args.local_rank == 0: - output_base = args.output if args.output else './output' - exp_name = '-'.join([ - datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, - str(args.img_size)]) - output_dir = get_outdir(output_base, 'train', exp_name) - model = create_model( args.model, pretrained=args.pretrained, @@ -291,13 +282,21 @@ def main(): validate_loss_fn = train_loss_fn eval_metric = args.eval_metric + best_metric = None + best_epoch = None saver = None - if output_dir: - # only set if process is rank 0 + output_dir = '' + if args.local_rank == 0: + output_base = args.output if args.output else './output' + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + args.model, + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(output_base, 'train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) - best_metric = None - best_epoch = None + try: for epoch in range(start_epoch, num_epochs): if args.distributed: diff --git a/utils.py b/utils.py index 43f58117..626ae9dc 100644 --- a/utils.py +++ b/utils.py @@ -253,9 +253,9 @@ class ModelEma: name = k new_state_dict[name] = v self.ema.load_state_dict(new_state_dict) - print("=> loaded state_dict_ema") + print("=> Loaded state_dict_ema") else: - print("=> failed to find state_dict_ema, starting from loaded model weights)") + print("=> Failed to find state_dict_ema, starting from loaded model weights") def update(self, model): # correct a mismatch in state dict keys
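# ---------------------------------------------------------------------------
# Reviewer note (illustrative sketch, not part of the patch): the update rule
# that ModelEma.update() applies once per optimizer step, i.e.
#     ema_v <- decay * ema_v + (1 - decay) * model_v
# for every tensor in the state_dict, parameters and buffers alike. The tiny
# model and the call sequence here are placeholders, not code from this repo.
# ---------------------------------------------------------------------------
from copy import deepcopy

import torch
import torch.nn as nn

def ema_update(model, ema_model, decay=0.9998):
    with torch.no_grad():
        msd = model.state_dict()
        for k, ema_v in ema_model.state_dict().items():
            ema_v.copy_(ema_v * decay + (1.0 - decay) * msd[k].detach())

net = nn.Linear(10, 2)              # placeholder model
ema_net = deepcopy(net).eval()      # averaged copy, never trained directly
for p in ema_net.parameters():
    p.requires_grad_(False)

# in a training loop, mirroring train_epoch() above:
#   loss.backward(); optimizer.step(); ema_update(net, ema_net)
# ---------------------------------------------------------------------------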
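# ---------------------------------------------------------------------------
# Reviewer note (illustrative sketch, not part of the patch): the unwrap logic
# that get_state_dict() in utils.py is reaching for. As added above it calls
# getattr(model, 'module') with no default, which raises AttributeError for a
# model that was never wrapped in (Distributed)DataParallel; hasattr() is the
# safer test, sketched below. unwrapped_state_dict is a hypothetical name.
# ---------------------------------------------------------------------------
from utils import ModelEma  # the class added by this patch

def unwrapped_state_dict(model):
    if isinstance(model, ModelEma):
        model = model.ema           # the averaged copy lives on .ema
    if hasattr(model, 'module'):
        model = model.module        # strip the DataParallel / DDP wrapper
    return model.state_dict()
# ---------------------------------------------------------------------------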