diff --git a/train.py b/train.py
index e51d7c90..ba365ba6 100755
--- a/train.py
+++ b/train.py
@@ -364,7 +364,7 @@ def _parse_args():
 
 
 def main():
-    utils.setup_default_logging()
+
     args, args_text = _parse_args()
 
     if torch.cuda.is_available():
@@ -373,8 +373,53 @@ def main():
     if args.data and not args.data_dir:
         args.data_dir = args.data
+    args.prefetcher = not args.no_prefetcher
     device = utils.init_distributed_device(args)
+
+    # setup model based on args
+    in_chans = 3
+    if args.in_chans is not None:
+        in_chans = args.in_chans
+    elif args.input_size is not None:
+        in_chans = args.input_size[0]
+
+    model = create_model(
+        args.model,
+        pretrained=args.pretrained,
+        in_chans=in_chans,
+        num_classes=args.num_classes,
+        drop_rate=args.drop,
+        drop_path_rate=args.drop_path,
+        drop_block_rate=args.drop_block,
+        global_pool=args.gp,
+        bn_momentum=args.bn_momentum,
+        bn_eps=args.bn_eps,
+        scriptable=args.torchscript,
+        checkpoint_path=args.initial_checkpoint,
+    )
+    if args.num_classes is None:
+        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
+        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
+
+    if args.grad_checkpointing:
+        model.set_grad_checkpointing(enable=True)
+
+    # initialize data config
+    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
+    output_dir = None
+    if args.experiment:
+        exp_name = args.experiment
+    else:
+        exp_name = '-'.join([
+            datetime.now().strftime("%Y%m%d-%H%M%S"),
+            safe_model_name(args.model),
+            str(data_config['input_size'][-1])
+        ])
+    # confirm output directory & write 'train.log' to this directory by default
+    output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
+    utils.setup_default_logging(log_path=os.path.join(output_dir, 'train.log'))
+
     if args.distributed:
         _logger.info(
             'Training in distributed mode with multiple processes, 1 device per process.'
@@ -390,7 +435,7 @@ def main():
             _logger.warning(
                 "You've requested to log metrics to wandb but package not found. "
                 "Metrics not being logged to wandb, try `pip install wandb`")
-
+
     # resolve AMP arguments based on PyTorch / Apex availability
     use_amp = None
     amp_dtype = torch.float16
@@ -413,38 +458,11 @@ def main():
     if args.fast_norm:
         set_fast_norm()
 
-    in_chans = 3
-    if args.in_chans is not None:
-        in_chans = args.in_chans
-    elif args.input_size is not None:
-        in_chans = args.input_size[0]
-
-    model = create_model(
-        args.model,
-        pretrained=args.pretrained,
-        in_chans=in_chans,
-        num_classes=args.num_classes,
-        drop_rate=args.drop,
-        drop_path_rate=args.drop_path,
-        drop_block_rate=args.drop_block,
-        global_pool=args.gp,
-        bn_momentum=args.bn_momentum,
-        bn_eps=args.bn_eps,
-        scriptable=args.torchscript,
-        checkpoint_path=args.initial_checkpoint,
-    )
-    if args.num_classes is None:
-        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
-        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
-
-    if args.grad_checkpointing:
-        model.set_grad_checkpointing(enable=True)
-
     if utils.is_primary(args):
         _logger.info(
             f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
 
-    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
+
     # setup augmentation batch splits for contrastive loss or split bn
     num_aug_splits = 0
@@ -686,17 +704,8 @@ def main():
     best_metric = None
     best_epoch = None
     saver = None
-    output_dir = None
     if utils.is_primary(args):
-        if args.experiment:
-            exp_name = args.experiment
-        else:
-            exp_name = '-'.join([
-                datetime.now().strftime("%Y%m%d-%H%M%S"),
-                safe_model_name(args.model),
-                str(data_config['input_size'][-1])
-            ])
-        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
+
         decreasing = True if eval_metric == 'loss' else False
         saver = utils.CheckpointSaver(
            model=model,
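
Note on the reordering above: model creation, data-config resolution, and output-directory setup now run before logging is initialized, so utils.setup_default_logging() can be pointed at a train.log file inside the run's output directory instead of only logging to the console. Below is a minimal standalone sketch of that resulting order of operations, not the actual train.py; get_outdir and setup_logging_for_run are simplified stand-ins invented for illustration, and model_name/input_size/output_base are placeholder parameters.

import logging
import os
from datetime import datetime


def get_outdir(base, name):
    # simplified stand-in for timm's utils.get_outdir: create the run directory and return its path
    path = os.path.join(base, name)
    os.makedirs(path, exist_ok=True)
    return path


def setup_logging_for_run(model_name='resnet50', input_size=224, output_base='./output/train'):
    # mirror the patched ordering: build exp_name from timestamp / model / resolution,
    # create the output directory, then attach a file handler for train.log
    exp_name = '-'.join([
        datetime.now().strftime("%Y%m%d-%H%M%S"),
        model_name,
        str(input_size),
    ])
    output_dir = get_outdir(output_base, exp_name)
    logging.basicConfig(
        level=logging.INFO,
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler(os.path.join(output_dir, 'train.log')),
        ],
    )
    return output_dir


if __name__ == '__main__':
    out = setup_logging_for_run()
    logging.getLogger('train').info('logging to %s', os.path.join(out, 'train.log'))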