slightly changed some setup order in main/train.py to achieve writing 'train.log' to output directory by default

pull/1617/head
hova88 2 years ago
parent 3698e79ac5
commit 2ed4871622

@ -364,7 +364,7 @@ def _parse_args():
def main():
utils.setup_default_logging()
args, args_text = _parse_args()
if torch.cuda.is_available():
@ -373,8 +373,53 @@ def main():
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args)
# setup model based on args
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
# initialize data config
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
output_dir = None
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
# confirm output directory & write 'train.log' to this directory by default
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
utils.setup_default_logging(log_path=os.path.join(output_dir, 'train.log'))
if args.distributed:
_logger.info(
'Training in distributed mode with multiple processes, 1 device per process.'
@ -390,7 +435,7 @@ def main():
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
@ -413,38 +458,11 @@ def main():
if args.fast_norm:
set_fast_norm()
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
in_chans=in_chans,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
if args.grad_checkpointing:
model.set_grad_checkpointing(enable=True)
if utils.is_primary(args):
_logger.info(
f'Model {safe_model_name(args.model)} created, param count:{sum([m.numel() for m in model.parameters()])}')
data_config = resolve_data_config(vars(args), model=model, verbose=utils.is_primary(args))
# setup augmentation batch splits for contrastive loss or split bn
num_aug_splits = 0
@ -686,17 +704,8 @@ def main():
best_metric = None
best_epoch = None
saver = None
output_dir = None
if utils.is_primary(args):
if args.experiment:
exp_name = args.experiment
else:
exp_name = '-'.join([
datetime.now().strftime("%Y%m%d-%H%M%S"),
safe_model_name(args.model),
str(data_config['input_size'][-1])
])
output_dir = utils.get_outdir(args.output if args.output else './output/train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = utils.CheckpointSaver(
model=model,

Loading…
Cancel
Save