@@ -364,7 +364,7 @@ def _parse_args():
def main():
    utils.setup_default_logging()
    args, args_text = _parse_args()

    if torch.cuda.is_available():
@@ -373,8 +373,53 @@ def main():
    if args.data and not args.data_dir:
        args.data_dir = args.data
    args.prefetcher = not args.no_prefetcher
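    # init_distributed_device() reads the usual launcher env vars (RANK,
    # WORLD_SIZE, LOCAL_RANK), sets args.distributed/world_size/rank to match,
    # and returns the torch.device this process should use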
    device = utils.init_distributed_device(args)

    # setup model based on args
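    # input channel resolution: --in-chans takes precedence, else the first
    # dim of --input-size, else default to 3 (RGB)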
    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

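    # gradient checkpointing trades extra forward compute for lower activation
    # memory; set_grad_checkpointing() toggles it on models that support it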
    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    # initialize data config
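    # resolve_data_config() merges CLI overrides with the model's pretrained
    # cfg defaults (input size, interpolation, mean/std, crop pct)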
    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    output_dir = None
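    # experiment name defaults to '<timestamp>-<model>-<input res>' when
    # --experiment is not given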
    if args.experiment:
        exp_name = args.experiment
    else:
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"),
            safe_model_name(args.model),
            str(data_config['input_size'][-1])
        ])
    # confirm output directory & write 'train.log' to this directory by default
    output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
    utils.setup_default_logging(log_path=os.path.join(output_dir, 'train.log'))

    if args.distributed:
        _logger.info(
            'Training in distributed mode with multiple processes, 1 device per process.'
@@ -390,7 +435,7 @@ def main():
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")

    # resolve AMP arguments based on PyTorch / Apex availability
    use_amp = None
    amp_dtype = torch.float16
@@ -413,38 +458,11 @@ def main():
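    # fast_norm enables timm's faster norm paths (e.g. skipping the fp32
    # upcast of LayerNorm/GroupNorm under AMP where it's safe to do so)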
    if args.fast_norm:
        set_fast_norm()

    in_chans = 3
    if args.in_chans is not None:
        in_chans = args.in_chans
    elif args.input_size is not None:
        in_chans = args.input_size[0]

    model = create_model(
        args.model,
        pretrained=args.pretrained,
        in_chans=in_chans,
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
        scriptable=args.torchscript,
        checkpoint_path=args.initial_checkpoint,
    )
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly

    if args.grad_checkpointing:
        model.set_grad_checkpointing(enable=True)

    if utils.is_primary(args):
        _logger.info(
            f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')

    data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))

    # setup augmentation batch splits for contrastive loss or split bn
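    # --aug-splits > 1 partitions each batch into differently-augmented groups,
    # used with SplitBatchNorm and the JSD loss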
    num_aug_splits = 0
@@ -686,17 +704,8 @@ def main():
    best_metric = None
    best_epoch = None
    saver = None
    output_dir = None
    if utils.is_primary(args):
        if args.experiment:
            exp_name = args.experiment
        else:
            exp_name = '-'.join([
                datetime.now().strftime("%Y%m%d-%H%M%S"),
                safe_model_name(args.model),
                str(data_config['input_size'][-1])
            ])
        output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)

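        # checkpoints are ranked by the eval metric; 'loss' improves downward,
        # accuracy-style metrics improve upward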
        decreasing = True if eval_metric == 'loss' else False
        saver = utils.CheckpointSaver(
            model=model,