Cleanup CheckpointSaver, add support for increasing or decreasing metric, switch to prec1 metric in train loop

pull/1/head
Ross Wightman 5 years ago
parent c0e6e5f3db
commit f1cd1a5ce3

@ -93,6 +93,8 @@ parser.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA amp for mixed precision training')
parser.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "prec1"')
parser.add_argument("--local_rank", default=0, type=int)
@ -238,10 +240,13 @@ def main():
if args.local_rank == 0:
print('Scheduled epochs: ', num_epochs)
eval_metric = args.eval_metric
saver = None
if output_dir:
saver = CheckpointSaver(checkpoint_dir=output_dir)
best_loss = None
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
best_metric = None
best_epoch = None
try:
for epoch in range(start_epoch, num_epochs):
if args.distributed:
@ -255,15 +260,15 @@ def main():
model, loader_eval, validate_loss_fn, args)
if lr_scheduler is not None:
lr_scheduler.step(epoch, eval_metrics['eval_loss'])
lr_scheduler.step(epoch, eval_metrics[eval_metric])
update_summary(
epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'),
write_header=best_loss is None)
write_header=best_metric is None)
if saver is not None:
# save proper checkpoint with eval metric
best_loss = saver.save_checkpoint({
best_metric, best_epoch = saver.save_checkpoint({
'epoch': epoch + 1,
'arch': args.model,
'state_dict': model.state_dict(),
@ -271,11 +276,12 @@ def main():
'args': args,
},
epoch=epoch + 1,
metric=eval_metrics['eval_loss'])
metric=eval_metrics[eval_metric])
except KeyboardInterrupt:
pass
print('*** Best loss: {0} (epoch {1})'.format(best_loss[1], best_loss[0]))
if best_metric is not None:
print('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
def train_epoch(
@ -363,7 +369,7 @@ def train_epoch(
end = time.time()
return OrderedDict([('train_loss', losses_m.avg)])
return OrderedDict([('loss', losses_m.avg)])
def validate(model, loader, loss_fn, args):
@ -418,7 +424,7 @@ def validate(model, loader, loss_fn, args):
batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))
metrics = OrderedDict([('eval_loss', losses_m.avg), ('eval_prec1', prec1_m.avg)])
metrics = OrderedDict([('loss', losses_m.avg), ('prec1', prec1_m.avg), ('prec5', prec5_m.avg)])
return metrics

@ -6,6 +6,7 @@ import os
import shutil
import glob
import csv
import operator
from collections import OrderedDict
@ -16,24 +17,32 @@ class CheckpointSaver:
recovery_prefix='recovery',
checkpoint_dir='',
recovery_dir='',
decreasing=False,
verbose=True,
max_history=10):
self.checkpoint_files = []
# state
self.checkpoint_files = [] # (filename, metric) tuples in order of decreasing betterness
self.best_epoch = None
self.best_metric = None
self.worst_metric = None
self.max_history = max_history
assert self.max_history >= 1
self.curr_recovery_file = ''
self.last_recovery_file = ''
# config
self.checkpoint_dir = checkpoint_dir
self.recovery_dir = recovery_dir
self.save_prefix = checkpoint_prefix
self.recovery_prefix = recovery_prefix
self.extension = '.pth.tar'
self.decreasing = decreasing # a lower metric is better if True
self.cmp = operator.lt if decreasing else operator.gt # True if lhs better than rhs
self.verbose = verbose
self.max_history = max_history
assert self.max_history >= 1
def save_checkpoint(self, state, epoch, metric=None):
worst_metric = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or metric < worst_metric[1]:
worst_file = self.checkpoint_files[-1] if self.checkpoint_files else None
if len(self.checkpoint_files) < self.max_history or self.cmp(metric, worst_file[1]):
if len(self.checkpoint_files) >= self.max_history:
self._cleanup_checkpoints(1)
@ -43,16 +52,21 @@ class CheckpointSaver:
state['metric'] = metric
torch.save(state, save_path)
self.checkpoint_files.append((save_path, metric))
self.checkpoint_files = sorted(self.checkpoint_files, key=lambda x: x[1])
print("Current checkpoints:")
for c in self.checkpoint_files:
print(c)
if metric is not None and (self.best_metric is None or metric < self.best_metric[1]):
self.best_metric = (epoch, metric)
self.checkpoint_files = sorted(
self.checkpoint_files, key=lambda x: x[1],
reverse=not self.decreasing) # sort in descending order if a lower metric is not better
if self.verbose:
print("Current checkpoints:")
for c in self.checkpoint_files:
print(c)
if metric is not None and (self.best_metric is None or self.cmp(metric, self.best_metric)):
self.best_epoch = epoch
self.best_metric = metric
shutil.copyfile(save_path, os.path.join(self.checkpoint_dir, 'model_best' + self.extension))
return None, None if self.best_metric is None else self.best_metric
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
def _cleanup_checkpoints(self, trim=0):
trim = min(len(self.checkpoint_files), trim)
@ -62,7 +76,8 @@ class CheckpointSaver:
to_delete = self.checkpoint_files[delete_index:]
for d in to_delete:
try:
print('Cleaning checkpoint: ', d)
if self.verbose:
print('Cleaning checkpoint: ', d)
os.remove(d[0])
except Exception as e:
print('Exception (%s) while deleting checkpoint' % str(e))
@ -74,7 +89,8 @@ class CheckpointSaver:
torch.save(state, save_path)
if os.path.exists(self.last_recovery_file):
try:
print('Cleaning recovery', self.last_recovery_file)
if self.verbose:
print('Cleaning recovery', self.last_recovery_file)
os.remove(self.last_recovery_file)
except Exception as e:
print("Exception (%s) while removing %s" % (str(e), self.last_recovery_file))
@ -143,8 +159,8 @@ def get_outdir(path, *paths, inc=False):
def update_summary(epoch, train_metrics, eval_metrics, filename, write_header=False):
rowd = OrderedDict(epoch=epoch)
rowd.update(train_metrics)
rowd.update(eval_metrics)
rowd.update([('train_' + k, v) for k, v in train_metrics.items()])
rowd.update([('eval_' + k, v) for k, v in eval_metrics.items()])
with open(filename, mode='a') as cf:
dw = csv.DictWriter(cf, fieldnames=rowd.keys())
if write_header: # first iteration (epoch == 1 can't be used)

Loading…
Cancel
Save