Fix tta train bug, improve logging

Branch: pull/1/head
Author: Ross Wightman
Parent: 72b4d162a2
Commit: c57717d325

@@ -62,7 +62,7 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
-parser.add_argument('-j', '--workers', type=int, default=6, metavar='N',
+parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 1)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
@@ -216,19 +216,6 @@ def main():
             epoch, model, loader_train, optimizer, train_loss_fn, args,
             saver=saver, output_dir=output_dir)
-        # save a recovery in case validation blows up
-        saver.save_recovery({
-            'epoch': epoch + 1,
-            'arch': args.model,
-            'state_dict': model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'loss': train_loss_fn.state_dict(),
-            'args': args,
-            'gp': args.gp,
-            },
-            epoch=epoch + 1,
-            batch_idx=0)
         step = epoch * len(loader_train)
         eval_metrics = validate(
             step, model, loader_eval, validate_loss_fn, args,
@@ -275,12 +262,14 @@ def train_epoch(
     model.train()
     end = time.time()
+    last_idx = len(loader) - 1
     for batch_idx, (input, target) in enumerate(loader):
+        last_batch = batch_idx == last_idx
         step = epoch_step + batch_idx
         data_time_m.update(time.time() - end)
         input = input.cuda()
-        if isinstance(target, list):
+        if isinstance(target, (tuple, list)):
             target = [t.cuda() for t in target]
         else:
             target = target.cuda()
@@ -295,7 +284,7 @@ def train_epoch(
         optimizer.step()
         batch_time_m.update(time.time() - end)
-        if batch_idx % args.log_interval == 0:
+        if last_batch or batch_idx % args.log_interval == 0:
             print('Train: {} [{}/{} ({:.0f}%)] '
                   'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
                   'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
@@ -303,7 +292,7 @@ def train_epoch(
                   'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                       epoch,
                       batch_idx * len(input), len(loader.sampler),
-                      100. * batch_idx / len(loader),
+                      100. * batch_idx / last_idx,
                       loss=losses_m,
                       batch_time=batch_time_m,
                       rate=input.size(0) / batch_time_m.val,
@@ -317,16 +306,17 @@ def train_epoch(
                     padding=0,
                     normalize=True)
-        if saver is not None and batch_idx % args.recovery_interval == 0:
+        if saver is not None and last_batch or batch_idx % args.recovery_interval == 0:
+            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
-                'epoch': epoch,
+                'epoch': save_epoch,
                 'arch': args.model,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'args': args,
                 'gp': args.gp,
                 },
-                epoch=epoch,
+                epoch=save_epoch,
                 batch_idx=batch_idx)
         end = time.time()
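For reference, Python binds `and` tighter than `or`, so the new recovery condition groups as `(saver is not None and last_batch) or (batch_idx % args.recovery_interval == 0)`; whether that grouping is the intent is not stated in the commit. A minimal standalone sketch of the grouping, with illustrative values only (names mirror the diff but the snippet stands alone):

    # illustrative values, not taken from the training loop
    saver, last_batch, batch_idx, recovery_interval = None, False, 0, 1000
    cond = saver is not None and last_batch or batch_idx % recovery_interval == 0
    # the same grouping written explicitly
    assert cond == ((saver is not None and last_batch) or (batch_idx % recovery_interval == 0))
    print(cond)  # True: the modulo branch is evaluated independently of saver

With the `save_epoch = epoch + 1 if last_batch else epoch` bump, the end-of-epoch recovery checkpoint is recorded under the next epoch, which is why the separate pre-validation `save_recovery` call in `main()` could be removed above.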
@@ -343,8 +333,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
     model.eval()
     end = time.time()
+    last_idx = len(loader) - 1
     with torch.no_grad():
         for batch_idx, (input, target) in enumerate(loader):
+            last_batch = batch_idx == last_idx
             input = input.cuda()
             if isinstance(target, list):
                 target = target[0].cuda()
@@ -353,11 +346,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
             output = model(input)
-            if isinstance(output, list):
+            if isinstance(output, (tuple, list)):
                 output = output[0]
             # augmentation reduction
-            reduce_factor = loader.dataset.get_aug_factor()
+            reduce_factor = args.tta
             if reduce_factor > 1:
                 output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                 target = target[0:target.size(0):reduce_factor]
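The tta fix reads the reduction factor from the `--tta` argument instead of a dataset method. A minimal sketch of what the `unfold(...).mean(dim=2)` reduction does, assuming the loader lays out the `tta` augmented copies of each sample as consecutive rows (all sizes below are made up for illustration):

    import torch

    tta = 2                      # e.g. original crop plus horizontal flip
    num_samples, num_classes = 4, 5
    batch = num_samples * tta    # augmented batch; copies of a sample are contiguous

    output = torch.randn(batch, num_classes)                    # model logits
    target = torch.arange(num_samples).repeat_interleave(tta)   # each label repeated tta times

    # group every `tta` consecutive rows and average -> one prediction per original sample
    output = output.unfold(0, tta, tta).mean(dim=2)             # (num_samples, num_classes)
    target = target[0:target.size(0):tta]                       # keep one label per group

    assert output.shape == (num_samples, num_classes)
    assert target.shape == (num_samples,)

As in the loop above, the averaging happens on the raw model output before the loss and accuracy metrics are computed.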
@@ -373,13 +366,13 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
             batch_time_m.update(time.time() - end)
             end = time.time()
-            if batch_idx % args.log_interval == 0:
+            if last_batch or batch_idx % args.log_interval == 0:
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                          batch_idx, len(loader),
+                          batch_idx, last_idx,
                           batch_time=batch_time_m, loss=losses_m,
                           top1=prec1_m, top5=prec5_m))
