From c57717d325dea896e44869ea508b6e8bc0475984 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 2 Feb 2019 10:17:04 -0800
Subject: [PATCH] Fix tta train bug, improve logging

---
 train.py | 41 +++++++++++++++++------------------------
 1 file changed, 17 insertions(+), 24 deletions(-)

diff --git a/train.py b/train.py
index 37f85138..4639250d 100644
--- a/train.py
+++ b/train.py
@@ -62,7 +62,7 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
-parser.add_argument('-j', '--workers', type=int, default=6, metavar='N',
+parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
@@ -216,19 +216,6 @@ def main():
             epoch, model, loader_train, optimizer, train_loss_fn, args,
             saver=saver, output_dir=output_dir)
 
-        # save a recovery in case validation blows up
-        saver.save_recovery({
-            'epoch': epoch + 1,
-            'arch': args.model,
-            'state_dict': model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'loss': train_loss_fn.state_dict(),
-            'args': args,
-            'gp': args.gp,
-            },
-            epoch=epoch + 1,
-            batch_idx=0)
-
         step = epoch * len(loader_train)
         eval_metrics = validate(
             step, model, loader_eval, validate_loss_fn, args,
@@ -275,12 +262,14 @@ def train_epoch(
     model.train()
 
     end = time.time()
+    last_idx = len(loader) - 1
     for batch_idx, (input, target) in enumerate(loader):
+        last_batch = batch_idx == last_idx
         step = epoch_step + batch_idx
         data_time_m.update(time.time() - end)
 
         input = input.cuda()
-        if isinstance(target, list):
+        if isinstance(target, (tuple, list)):
             target = [t.cuda() for t in target]
         else:
             target = target.cuda()
@@ -295,7 +284,7 @@ def train_epoch(
         optimizer.step()
 
         batch_time_m.update(time.time() - end)
-        if batch_idx % args.log_interval == 0:
+        if last_batch or batch_idx % args.log_interval == 0:
             print('Train: {} [{}/{} ({:.0f}%)] '
                   'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
                   'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
@@ -303,7 +292,7 @@ def train_epoch(
                   '({batch_time.avg:.3f}s, {rate_avg:.3f}/s) '
                   'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                 epoch,
                 batch_idx * len(input), len(loader.sampler),
-                100. * batch_idx / len(loader),
+                100. * batch_idx / last_idx,
                 loss=losses_m,
                 batch_time=batch_time_m,
                 rate=input.size(0) / batch_time_m.val,
@@ -317,16 +306,17 @@ def train_epoch(
                     padding=0,
                     normalize=True)
 
-        if saver is not None and batch_idx % args.recovery_interval == 0:
+        if saver is not None and (last_batch or batch_idx % args.recovery_interval == 0):
+            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
-                'epoch': epoch,
+                'epoch': save_epoch,
                 'arch': args.model,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'args': args,
                 'gp': args.gp,
                 },
-                epoch=epoch,
+                epoch=save_epoch,
                 batch_idx=batch_idx)
 
         end = time.time()
@@ -343,8 +333,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
     model.eval()
 
     end = time.time()
+    last_idx = len(loader) - 1
     with torch.no_grad():
         for batch_idx, (input, target) in enumerate(loader):
+            last_batch = batch_idx == last_idx
+
             input = input.cuda()
             if isinstance(target, list):
                 target = target[0].cuda()
@@ -353,11 +346,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
 
             output = model(input)
-            if isinstance(output, list):
+            if isinstance(output, (tuple, list)):
                 output = output[0]
 
             # augmentation reduction
-            reduce_factor = loader.dataset.get_aug_factor()
+            reduce_factor = args.tta
             if reduce_factor > 1:
                 output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                 target = target[0:target.size(0):reduce_factor]
@@ -373,13 +366,13 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
 
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if batch_idx % args.log_interval == 0:
+            if last_batch or batch_idx % args.log_interval == 0:
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
                           batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))
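
A note on the last_batch guard shared by the logging and recovery-checkpoint changes, reduced to a toy loop. The interval values and the print stand-ins below are hypothetical, not taken from train.py:

    num_batches = 10                 # stand-in for len(loader)
    log_interval = 4                 # stand-in for args.log_interval
    recovery_interval = 6            # stand-in for args.recovery_interval
    saver = None                     # recovery saving is optional
    last_idx = num_batches - 1

    for batch_idx in range(num_batches):
        last_batch = batch_idx == last_idx

        # Always log the final batch too, so an epoch ends with a summary
        # line even when num_batches is not a multiple of log_interval.
        if last_batch or batch_idx % log_interval == 0:
            print('log at batch', batch_idx)

        # Keep the or-expression parenthesized: 'and' binds tighter than
        # 'or' in Python, so without parentheses a None saver would be
        # dereferenced whenever the interval test alone fired.
        if saver is not None and (last_batch or batch_idx % recovery_interval == 0):
            print('recovery checkpoint at batch', batch_idx)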
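
And a note on the TTA reduction in the validate() hunk, since it is the headline fix: the sketch below (not part of the patch) shows what output.unfold(0, reduce_factor, reduce_factor).mean(dim=2) does, assuming the eval loader lays out the args.tta augmented copies of each sample as consecutive rows. The tta value and tensor sizes are made up for illustration.

    import torch

    tta = 4                       # stand-in for args.tta
    n_samples, n_classes = 2, 10

    # Mimic the assumed loader layout: each sample contributes tta
    # consecutive rows of predictions.
    output = torch.randn(n_samples * tta, n_classes)
    target = torch.arange(n_samples).repeat_interleave(tta)

    # unfold(0, tta, tta) regroups the rows into (n_samples, n_classes, tta)
    # windows; mean(dim=2) then averages the augmented predictions, leaving
    # one row per underlying sample.
    output = output.unfold(0, tta, tta).mean(dim=2)
    # Targets are identical within a window, so keep every tta-th one.
    target = target[0:target.size(0):tta]

    assert output.shape == (n_samples, n_classes)
    assert target.tolist() == [0, 1]

A plain output.view(n_samples, tta, -1).mean(dim=1) would give the same averages here; unfold expresses the same grouping as a strided view.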