Add support for loading args from yaml file (and saving them with each experiment)

pull/65/head
Ross Wightman 5 years ago
parent d3ba34ee7e
commit 187ecbafbe

@ -1,2 +1,3 @@
torch>=1.1.0
torchvision>=0.3.0
pyyaml

@ -2,6 +2,7 @@
import argparse
import time
import logging
import yaml
from datetime import datetime
try:
@ -26,6 +27,14 @@ import torchvision.utils
torch.backends.cudnn.benchmark = True
# The first arg parser parses out only the --config argument, this argument is used to
# load a yaml file containing key-values that override the defaults for the main parser below
config_parser = parser = argparse.ArgumentParser(description='Training Config', add_help=False)
parser.add_argument('-c', '--config', default='', type=str, metavar='FILE',
help='YAML config file specifying default arguments')
parser = argparse.ArgumentParser(description='Training')
# Dataset / Model parameters
parser.add_argument('data', metavar='DIR',
@ -145,9 +154,27 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
parser.add_argument("--local_rank", default=0, type=int)
def _parse_args():
# Do we have a config file to parse?
args_config, remaining = config_parser.parse_known_args()
if args_config.config:
with open(args_config.config, 'r') as f:
cfg = yaml.safe_load(f)
parser.set_defaults(**cfg)
# The main arg parser parses the rest of the args, the usual
# defaults will have been overridden if config file specified.
args = parser.parse_args(remaining)
# Cache the args as a text string to save them in the output dir later
args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
return args, args_text
def main():
setup_default_logging()
args = parser.parse_args()
args, args_text = _parse_args()
args.prefetcher = not args.no_prefetcher
args.distributed = False
if 'WORLD_SIZE' in os.environ:
@ -345,6 +372,8 @@ def main():
output_dir = get_outdir(output_base, 'train', exp_name)
decreasing = True if eval_metric == 'loss' else False
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
try:
for epoch in range(start_epoch, num_epochs):

Loading…
Cancel
Save