Fix tta train bug, improve logging

pull/1/head
Ross Wightman 6 years ago
parent 72b4d162a2
commit c57717d325

@ -62,7 +62,7 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
parser.add_argument('-j', '--workers', type=int, default=6, metavar='N',
parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 1)')
parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use')
@ -216,19 +216,6 @@ def main():
epoch, model, loader_train, optimizer, train_loss_fn, args,
saver=saver, output_dir=output_dir)
# save a recovery in case validation blows up
saver.save_recovery({
'epoch': epoch + 1,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'loss': train_loss_fn.state_dict(),
'args': args,
'gp': args.gp,
},
epoch=epoch + 1,
batch_idx=0)
step = epoch * len(loader_train)
eval_metrics = validate(
step, model, loader_eval, validate_loss_fn, args,
@ -275,12 +262,14 @@ def train_epoch(
model.train()
end = time.time()
last_idx = len(loader) - 1
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
step = epoch_step + batch_idx
data_time_m.update(time.time() - end)
input = input.cuda()
if isinstance(target, list):
if isinstance(target, (tuple, list)):
target = [t.cuda() for t in target]
else:
target = target.cuda()
@ -295,7 +284,7 @@ def train_epoch(
optimizer.step()
batch_time_m.update(time.time() - end)
if batch_idx % args.log_interval == 0:
if last_batch or batch_idx % args.log_interval == 0:
print('Train: {} [{}/{} ({:.0f}%)] '
'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
@ -303,7 +292,7 @@ def train_epoch(
'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
epoch,
batch_idx * len(input), len(loader.sampler),
100. * batch_idx / len(loader),
100. * batch_idx / last_idx,
loss=losses_m,
batch_time=batch_time_m,
rate=input.size(0) / batch_time_m.val,
@ -317,16 +306,17 @@ def train_epoch(
padding=0,
normalize=True)
if saver is not None and batch_idx % args.recovery_interval == 0:
if saver is not None and last_batch or batch_idx % args.recovery_interval == 0:
save_epoch = epoch + 1 if last_batch else epoch
saver.save_recovery({
'epoch': epoch,
'epoch': save_epoch,
'arch': args.model,
'state_dict': model.state_dict(),
'optimizer': optimizer.state_dict(),
'args': args,
'gp': args.gp,
},
epoch=epoch,
epoch=save_epoch,
batch_idx=batch_idx)
end = time.time()
@ -343,8 +333,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
model.eval()
end = time.time()
last_idx = len(loader) - 1
with torch.no_grad():
for batch_idx, (input, target) in enumerate(loader):
last_batch = batch_idx == last_idx
input = input.cuda()
if isinstance(target, list):
target = target[0].cuda()
@ -353,11 +346,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
output = model(input)
if isinstance(output, list):
if isinstance(output, (tuple, list)):
output = output[0]
# augmentation reduction
reduce_factor = loader.dataset.get_aug_factor()
reduce_factor = args.tta
if reduce_factor > 1:
output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
target = target[0:target.size(0):reduce_factor]
@ -373,13 +366,13 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
batch_time_m.update(time.time() - end)
end = time.time()
if batch_idx % args.log_interval == 0:
if last_batch or batch_idx % args.log_interval == 0:
print('Test: [{0}/{1}]\t'
'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
'Loss {loss.val:.4f} ({loss.avg:.4f}) '
'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
batch_idx, len(loader),
batch_idx, last_idx,
batch_time=batch_time_m, loss=losses_m,
top1=prec1_m, top5=prec5_m))

Loading…
Cancel
Save