diff --git a/requirements.txt b/requirements.txt
index 5cde46bc..88ce152f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,3 @@
 torch>=1.1.0
 torchvision>=0.3.0
+pyyaml
diff --git a/train.py b/train.py
index f6b6e407..d4dd7332 100644
--- a/train.py
+++ b/train.py
@@ -2,6 +2,7 @@
 import argparse
 import time
 import logging
+import yaml
 from datetime import datetime
 
 try:
@@ -26,6 +27,14 @@
 import torchvision.utils
 
 torch.backends.cudnn.benchmark = True
+
+# The first arg parser parses out only the --config argument, this argument is used to
+# load a yaml file containing key-values that override the defaults for the main parser below
+config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
+parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
+                    help='YAML config file specifying default arguments')
+
+
 parser = argparse.ArgumentParser(description='Training')
 # Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
@@ -145,9 +154,27 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
 parser.add_argument("--local_rank", default=0, type=int)
 
 
+def _parse_args():
+    # Do we have a config file to parse?
+    args_config, remaining = config_parser.parse_known_args()
+    if args_config.config:
+        with open(args_config.config, 'r') as f:
+            cfg = yaml.safe_load(f)
+            parser.set_defaults(**cfg)
+
+    # The main arg parser parses the rest of the args, the usual
+    # defaults will have been overridden if config file specified.
+    args = parser.parse_args(remaining)
+
+    # Cache the args as a text string to save them in the output dir later
+    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
+    return args, args_text
+
+
 def main():
     setup_default_logging()
-    args = parser.parse_args()
+    args, args_text = _parse_args()
+
     args.prefetcher = not args.no_prefetcher
     args.distributed = False
     if 'WORLD_SIZE' in os.environ:
@@ -345,6 +372,8 @@ def main():
         output_dir = get_outdir(output_base, 'train', exp_name)
         decreasing = True if eval_metric == 'loss' else False
         saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
+        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
+            f.write(args_text)
 
     try:
         for epoch in range(start_epoch, num_epochs):
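
For reference, the two-pass parsing pattern introduced in this diff can be tried in isolation. The sketch below is a minimal, self-contained illustration rather than the project's actual code: --lr and --epochs are hypothetical stand-ins for train.py's full argument set, and demo.yaml is an assumed file name.

import argparse
import yaml

# Pass 1: a bare parser that only looks for --config. add_help=False keeps
# it from swallowing -h/--help before the main parser can respond to it.
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument('-c', '--config', default='', type=str, metavar='FILE')

# Pass 2: the main parser. These two options are hypothetical stand-ins
# for the real training arguments.
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--epochs', type=int, default=10)


def _parse_args():
    # parse_known_args() returns (namespace, leftover argv); the leftovers
    # are handed to the main parser below.
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            # YAML keys become new defaults, so explicit CLI flags still
            # win because they are parsed afterwards.
            parser.set_defaults(**yaml.safe_load(f))
    return parser.parse_args(remaining)


if __name__ == '__main__':
    print(_parse_args())

With a demo.yaml containing "lr: 0.1", running "python demo.py -c demo.yaml --epochs 20" yields lr=0.1 (the YAML value overrides the argparse default) and epochs=20 (the CLI flag overrides both), i.e. the precedence is CLI flag > YAML value > argparse default, which is the behavior the comments in the diff describe.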