diff --git a/train.py b/train.py index 2fdf68d8..d0886c08 100755 --- a/train.py +++ b/train.py @@ -259,6 +259,8 @@ parser.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') +parser.add_argument('--experiment', default='', type=str, metavar='NAME', + help='name of train experiment, name of sub-folder for output') parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "top1"') parser.add_argument('--tta', type=int, default=0, metavar='N', @@ -544,13 +546,15 @@ def main(): saver = None output_dir = '' if args.local_rank == 0: - output_base = args.output if args.output else './output' - exp_name = '-'.join([ - datetime.now().strftime("%Y%m%d-%H%M%S"), - args.model, - str(data_config['input_size'][-1]) - ]) - output_dir = get_outdir(output_base, 'train', exp_name) + if args.experiment: + exp_name = args.experiment + else: + exp_name = '-'.join([ + datetime.now().strftime("%Y%m%d-%H%M%S"), + args.model, + str(data_config['input_size'][-1]) + ]) + output_dir = get_outdir(args.output if args.output else './output/train', exp_name) decreasing = True if eval_metric == 'loss' else False saver = CheckpointSaver( model=model, optimizer=optimizer, args=args, model_ema=model_ema, amp_scaler=loss_scaler,