@@ -62,7 +62,7 @@ parser.add_argument('--log-interval', type=int, default=50, metavar='N',
                     help='how many batches to wait before logging training status')
 parser.add_argument('--recovery-interval', type=int, default=1000, metavar='N',
                     help='how many batches to wait before writing recovery checkpoint')
-parser.add_argument('-j', '--workers', type=int, default=6, metavar='N',
-                    help='how many training processes to use (default: 1)')
+parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
+                    help='how many training processes to use (default: 4)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
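For a quick sanity check on the new worker default, a minimal self-contained sketch of the flag as it now reads (only this one argument is reproduced; the rest of the parser is assumed):

    import argparse

    # reproduces just the -j/--workers flag from the hunk above
    parser = argparse.ArgumentParser()
    parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                        help='how many training processes to use (default: 4)')
    print(parser.parse_args([]).workers)  # -> 4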
@@ -216,19 +216,6 @@ def main():
             epoch, model, loader_train, optimizer, train_loss_fn, args,
             saver=saver, output_dir=output_dir)

-        # save a recovery in case validation blows up
-        saver.save_recovery({
-            'epoch': epoch + 1,
-            'arch': args.model,
-            'state_dict': model.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'loss': train_loss_fn.state_dict(),
-            'args': args,
-            'gp': args.gp,
-            },
-            epoch=epoch + 1,
-            batch_idx=0)
-
         step = epoch * len(loader_train)
         eval_metrics = validate(
             step, model, loader_eval, validate_loss_fn, args,
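The block removed here is the epoch-end "save a recovery in case validation blows up" checkpoint. It becomes redundant with the train_epoch changes below: once the recovery condition also fires on the last batch of the epoch and records epoch + 1, train_epoch persists essentially the same state this block captured (the only field dropped in the move is the loss-function state_dict).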
@@ -275,12 +262,14 @@ def train_epoch(
     model.train()

     end = time.time()
+    last_idx = len(loader) - 1
     for batch_idx, (input, target) in enumerate(loader):
+        last_batch = batch_idx == last_idx
         step = epoch_step + batch_idx
         data_time_m.update(time.time() - end)

         input = input.cuda()
-        if isinstance(target, list):
+        if isinstance(target, (tuple, list)):
             target = [t.cuda() for t in target]
         else:
             target = target.cuda()
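The last_batch flag introduced here feeds the logging and recovery conditions later in the patch; the intent is that interval-based actions also fire on the final batch of the epoch, which a bare modulo test would usually miss. A self-contained toy run of the idiom (the list and interval are stand-ins for the real loader and args.log_interval):

    # toy stand-ins for the DataLoader and args.log_interval
    loader = [('batch%d' % i, i) for i in range(10)]
    interval = 4
    last_idx = len(loader) - 1
    for batch_idx, (input, target) in enumerate(loader):
        last_batch = batch_idx == last_idx
        if last_batch or batch_idx % interval == 0:
            print(batch_idx)  # fires at 0, 4, 8, and also at the final 9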
@@ -295,7 +284,7 @@ def train_epoch(
         optimizer.step()

         batch_time_m.update(time.time() - end)
-        if batch_idx % args.log_interval == 0:
+        if last_batch or batch_idx % args.log_interval == 0:
             print('Train: {} [{}/{} ({:.0f}%)] '
                   'Loss: {loss.val:.6f} ({loss.avg:.4f}) '
                   'Time: {batch_time.val:.3f}s, {rate:.3f}/s '
@@ -303,7 +292,7 @@ def train_epoch(
                   'Data: {data_time.val:.3f} ({data_time.avg:.3f})'.format(
                 epoch,
                 batch_idx * len(input), len(loader.sampler),
-                100. * batch_idx / len(loader),
+                100. * batch_idx / last_idx,
                 loss=losses_m,
                 batch_time=batch_time_m,
                 rate=input.size(0) / batch_time_m.val,
@@ -317,16 +306,17 @@ def train_epoch(
                 padding=0,
                 normalize=True)

-        if saver is not None and batch_idx % args.recovery_interval == 0:
+        if saver is not None and (last_batch or batch_idx % args.recovery_interval == 0):
+            save_epoch = epoch + 1 if last_batch else epoch
             saver.save_recovery({
-                'epoch': epoch,
+                'epoch': save_epoch,
                 'arch': args.model,
                 'state_dict': model.state_dict(),
                 'optimizer': optimizer.state_dict(),
                 'args': args,
                 'gp': args.gp,
                 },
-                epoch=epoch,
+                epoch=save_epoch,
                 batch_idx=batch_idx)

         end = time.time()
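A subtlety worth calling out in the recovery condition above: Python's `and` binds tighter than `or`, so without the parentheses the modulo branch would bypass the `saver is not None` guard and saver.save_recovery could be reached with no saver configured. A minimal self-contained check of the two forms:

    # demonstrates the precedence trap the parentheses avoid
    saver, last_batch, batch_idx, interval = None, False, 2000, 1000
    unguarded = saver is not None and last_batch or batch_idx % interval == 0
    guarded = saver is not None and (last_batch or batch_idx % interval == 0)
    print(unguarded, guarded)  # -> True False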
@@ -343,8 +333,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):
     model.eval()

     end = time.time()
+    last_idx = len(loader) - 1
     with torch.no_grad():
         for batch_idx, (input, target) in enumerate(loader):
+            last_batch = batch_idx == last_idx
+
             input = input.cuda()
             if isinstance(target, list):
                 target = target[0].cuda()
@@ -353,11 +346,11 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):

             output = model(input)

-            if isinstance(output, list):
+            if isinstance(output, (tuple, list)):
                 output = output[0]

             # augmentation reduction
-            reduce_factor = loader.dataset.get_aug_factor()
+            reduce_factor = args.tta
             if reduce_factor > 1:
                 output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
                 target = target[0:target.size(0):reduce_factor]
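On the augmentation reduction above: reduce_factor now comes from args.tta instead of a dataset method, and the unfold/mean pair assumes the loader emits reduce_factor augmented copies of each sample back to back, so each group of reduce_factor consecutive output rows is averaged into one prediction and the targets are thinned to match. A small self-contained check of the tensor math (toy values, not from the patch):

    import torch

    reduce_factor = 2  # stands in for args.tta
    output = torch.arange(8.).reshape(4, 2)  # 4 rows = 2 samples x 2 augmented copies
    target = torch.tensor([7, 7, 3, 3])      # target duplicated per augmented copy
    output = output.unfold(0, reduce_factor, reduce_factor).mean(dim=2)
    target = target[0:target.size(0):reduce_factor]
    print(output)  # tensor([[1., 2.], [5., 6.]]) -- per-sample averages
    print(target)  # tensor([7, 3])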
@@ -373,13 +366,13 @@ def validate(step, model, loader, loss_fn, args, output_dir=''):

             batch_time_m.update(time.time() - end)
             end = time.time()
-            if batch_idx % args.log_interval == 0:
+            if last_batch or batch_idx % args.log_interval == 0:
                 print('Test: [{0}/{1}]\t'
                       'Time {batch_time.val:.3f} ({batch_time.avg:.3f}) '
                       'Loss {loss.val:.4f} ({loss.avg:.4f}) '
                       'Prec@1 {top1.val:.4f} ({top1.avg:.4f}) '
                       'Prec@5 {top5.val:.4f} ({top5.avg:.4f})'.format(
-                    batch_idx, len(loader),
+                    batch_idx, last_idx,
                     batch_time=batch_time_m, loss=losses_m,
                     top1=prec1_m, top5=prec5_m))
