From f1cd1a5ce3e1e9dde77f9cb5d024f6e3268da1b9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 7 Apr 2019 10:22:55 -0700 Subject: [PATCH] Cleanup CheckpointSaver, add support for increasing or decreasing metric, switch to prec1 metric in train loop --- train.py | 24 +++++++++++++++--------- utils.py | 54 +++++++++++++++++++++++++++++++++++------------------- 2 files changed, 50 insertions(+), 28 deletions(-) diff --git a/train.py b/train.py index 89632d39..5494b3ff 100644 --- a/train.py +++ b/train.py @@ -93,6 +93,8 @@ parser.add_argument('--amp', action='store_true', default=False, help='use NVIDIA amp for mixed precision training') parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') +parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC', + help='Best metric (default: "prec1"') parser.add_argument("--local_rank", default=0, type=int) @@ -238,10 +240,13 @@ def main(): if args.local_rank == 0: print('Scheduled epochs: ', num_epochs) + eval_metric = args.eval_metric saver = None if output_dir: - saver = CheckpointSaver(checkpoint_dir=output_dir) - best_loss = None + decreasing = True if eval_metric == 'loss' else False + saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing) + best_metric = None + best_epoch = None try: for epoch in range(start_epoch, num_epochs): if args.distributed: @@ -255,15 +260,15 @@ def main(): model, loader_eval, validate_loss_fn, args) if lr_scheduler is not None: - lr_scheduler.step(epoch, eval_metrics['eval_loss']) + lr_scheduler.step(epoch, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), - write_header=best_loss is None) + write_header=best_metric is None) if saver is not None: # save proper checkpoint with eval metric - best_loss = saver.save_checkpoint({ + best_metric, best_epoch = saver.save_checkpoint({ 'epoch': epoch + 1, 'arch': args.model, 'state_dict': model.state_dict(), @@ -271,11 +276,12 @@ def main(): 'args': args, }, epoch=epoch + 1, - metric=eval_metrics['eval_loss']) + metric=eval_metrics[eval_metric]) except KeyboardInterrupt: pass - print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0])) + if best_metric is not None: + print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch)) def train_epoch( @@ -363,7 +369,7 @@ def train_epoch( end = time.time() - return OrderedDict([('train_loss', losses_m.avg)]) + return OrderedDict([('loss', losses_m.avg)]) def validate(model, loader, loss_fn, args): @@ -418,7 +424,7 @@ def validate(model, loader, loss_fn, args): batch_time=batch_time_m, loss=losses_m, top1=prec1_m, top5=prec5_m)) - metrics = OrderedDict([('eval_loss', losses_m.avg), ('eval_prec1', prec1_m.avg)]) + metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)]) return metrics diff --git a/utils.py b/utils.py index 4604a258..f206f945 100644 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ import os import shutil import glob import csv +import operator from collections import OrderedDict @@ -16,24 +17,32 @@ class CheckpointSaver: recovery_prefix='recovery', checkpoint_dir='', recovery_dir='', + decreasing=False, + verbose=True, max_history=10): - self.checkpoint_files = [] + # state + self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness + self.best_epoch = None self.best_metric = None - self.worst_metric = None - self.max_history = max_history - assert self.max_history >= 1 self.curr_recovery_file = '' self.last_recovery_file = '' + + # config self.checkpoint_dir = checkpoint_dir self.recovery_dir = recovery_dir self.save_prefix = checkpoint_prefix self.recovery_prefix = recovery_prefix self.extension = '.pth.tar' + self.decreasing = decreasing # a lower metric is better if True + self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs + self.verbose = verbose + self.max_history = max_history + assert self.max_history >= 1 def save_checkpoint(self, state, epoch, metric=None): - worst_metric = self.checkpoint_files[-1] if self.checkpoint_files else None - if len(self.checkpoint_files) < self.max_history or metric < worst_metric[1]: + worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None + if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]): if len(self.checkpoint_files) >= self.max_history: self._cleanup_checkpoints(1) @@ -43,16 +52,21 @@ class CheckpointSaver: state['metric'] = metric torch.save(state, save_path) self.checkpoint_files.append((save_path, metric)) - self.checkpoint_files = sorted(self.checkpoint_files, key=lambda x: x[1]) - - print("Current checkpoints:") - for c in self.checkpoint_files: - print(c) - - if metric is not None and (self.best_metric is None or metric < self.best_metric[1]): - self.best_metric = (epoch, metric) + self.checkpoint_files = sorted( + self.checkpoint_files, key=lambda x: x[1], + reverse=not self.decreasing) # sort in descending order if a lower metric is not better + + if self.verbose: + print("Current checkpoints:") + for c in self.checkpoint_files: + print(c) + + if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)): + self.best_epoch = epoch + self.best_metric = metric shutil.copyfile(save_path, os.path.join(self.checkpoint_dir, 'model_best' + self.extension)) - return None, None if self.best_metric is None else self.best_metric + + return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch) def _cleanup_checkpoints(self, trim=0): trim = min(len(self.checkpoint_files), trim) @@ -62,7 +76,8 @@ class CheckpointSaver: to_delete = self.checkpoint_files[delete_index:] for d in to_delete: try: - print('Cleaning checkpoint: ', d) + if self.verbose: + print('Cleaning checkpoint: ', d) os.remove(d[0]) except Exception as e: print('Exception (%s) while deleting checkpoint' % str(e)) @@ -74,7 +89,8 @@ class CheckpointSaver: torch.save(state, save_path) if os.path.exists(self.last_recovery_file): try: - print('Cleaning recovery', self.last_recovery_file) + if self.verbose: + print('Cleaning recovery', self.last_recovery_file) os.remove(self.last_recovery_file) except Exception as e: print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file)) @@ -143,8 +159,8 @@ def get_outdir(path, *paths, inc=False): def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False): rowd = OrderedDict(epoch=epoch) - rowd.update(train_metrics) - rowd.update(eval_metrics) + rowd.update([('train_' + k, v) for k, v in train_metrics.items()]) + rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()]) with open(filename, mode='a') as cf: dw = csv.DictWriter(cf, fieldnames=rowd.keys()) if write_header: # first iteration (epoch == 1 can't be used)