|
|
|
@ -42,8 +42,8 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|
|
|
|
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
|
|
|
|
|
parser.add_argument('--pretrained', action='store_true', default=False,
|
|
|
|
|
help='Start with pretrained version of specified network (if avail)')
|
|
|
|
|
parser.add_argument('--img-size', type=int, default=224, metavar='N',
|
|
|
|
|
help='Image patch size (default: 224)')
|
|
|
|
|
parser.add_argument('--img-size', type=int, default=None, metavar='N',
|
|
|
|
|
help='Image patch size (default: None => model default)')
|
|
|
|
|
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
|
|
|
|
|
help='Override mean pixel value of dataset')
|
|
|
|
|
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
|
|
|
|
@ -159,15 +159,6 @@ def main():
|
|
|
|
|
|
|
|
|
|
torch.manual_seed(args.seed + args.rank)
|
|
|
|
|
|
|
|
|
|
output_dir = ''
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
output_base = args.output if args.output else './output'
|
|
|
|
|
exp_name = '-'.join([
|
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
|
args.model,
|
|
|
|
|
str(args.img_size)])
|
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
|
|
|
|
|
|
model = create_model(
|
|
|
|
|
args.model,
|
|
|
|
|
pretrained=args.pretrained,
|
|
|
|
@ -291,13 +282,21 @@ def main():
|
|
|
|
|
validate_loss_fn = train_loss_fn
|
|
|
|
|
|
|
|
|
|
eval_metric = args.eval_metric
|
|
|
|
|
best_metric = None
|
|
|
|
|
best_epoch = None
|
|
|
|
|
saver = None
|
|
|
|
|
if output_dir:
|
|
|
|
|
# only set if process is rank 0
|
|
|
|
|
output_dir = ''
|
|
|
|
|
if args.local_rank == 0:
|
|
|
|
|
output_base = args.output if args.output else './output'
|
|
|
|
|
exp_name = '-'.join([
|
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
|
args.model,
|
|
|
|
|
str(data_config['input_size'][-1])
|
|
|
|
|
])
|
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
|
|
|
|
best_metric = None
|
|
|
|
|
best_epoch = None
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
for epoch in range(start_epoch, num_epochs):
|
|
|
|
|
if args.distributed:
|
|
|
|
|