|
|
@ -259,6 +259,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False,
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
help='disable fast prefetcher')
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
|
parser.add_argument('--output', default='', type=str, metavar='PATH',
|
|
|
|
help='path to output folder (default: none, current dir)')
|
|
|
|
help='path to output folder (default: none, current dir)')
|
|
|
|
|
|
|
|
parser.add_argument('--experiment', default='', type=str, metavar='NAME',
|
|
|
|
|
|
|
|
help='name of train experiment, name of sub-folder for output')
|
|
|
|
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
|
|
|
|
parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
|
|
|
|
help='Best metric (default: "top1"')
|
|
|
|
help='Best metric (default: "top1"')
|
|
|
|
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|
|
|
parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|
|
@ -544,13 +546,15 @@ def main():
|
|
|
|
saver = None
|
|
|
|
saver = None
|
|
|
|
output_dir = ''
|
|
|
|
output_dir = ''
|
|
|
|
if args.local_rank == 0:
|
|
|
|
if args.local_rank == 0:
|
|
|
|
output_base = args.output if args.output else './output'
|
|
|
|
if args.experiment:
|
|
|
|
exp_name = '-'.join([
|
|
|
|
exp_name = args.experiment
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
else:
|
|
|
|
args.model,
|
|
|
|
exp_name = '-'.join([
|
|
|
|
str(data_config['input_size'][-1])
|
|
|
|
datetime.now().strftime("%Y%m%d-%H%M%S"),
|
|
|
|
])
|
|
|
|
args.model,
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
str(data_config['input_size'][-1])
|
|
|
|
|
|
|
|
])
|
|
|
|
|
|
|
|
output_dir = get_outdir(args.output if args.output else './output/train', exp_name)
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
saver = CheckpointSaver(
|
|
|
|
saver = CheckpointSaver(
|
|
|
|
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
|
|
|
|
model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,
|
|
|
|