|
|
|
@ -2,6 +2,7 @@
|
|
|
|
|
import argparse
|
|
|
|
|
import time
|
|
|
|
|
import logging
|
|
|
|
|
import yaml
|
|
|
|
|
from datetime import datetime
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
@ -26,6 +27,14 @@ import torchvision.utils
|
|
|
|
|
|
|
|
|
|
torch.backends.cudnn.benchmark = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Stage one of argument parsing: a minimal parser that recognises only
# -c/--config. The YAML file it names supplies key-values that replace the
# defaults of the main parser defined below. Help is disabled here so that
# -h/--help is left for the main parser to handle.
config_parser = argparse.ArgumentParser(
    description='Training Config',
    add_help=False,
)
config_parser.add_argument(
    '-c', '--config',
    default='',
    type=str,
    metavar='FILE',
    help='YAML config file specifying default arguments',
)
# `parser` starts out as an alias of the config parser.
parser = config_parser
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
parser = argparse.ArgumentParser(description='Training')
|
|
|
|
|
# Dataset / Model parameters
|
|
|
|
|
parser.add_argument('data', metavar='DIR',
|
|
|
|
@ -145,9 +154,27 @@ parser.add_argument('--tta', type=int, default=0, metavar='N',
|
|
|
|
|
parser.add_argument("--local_rank", default=0, type=int)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _parse_args():
    """Parse command-line arguments, optionally seeded from a YAML config file.

    Parsing happens in two stages. ``config_parser`` first extracts only
    ``-c/--config`` and leaves every other argument untouched. When a config
    file is given, its top-level key-values are installed as defaults on the
    main ``parser``, which then parses the remaining arguments.

    Returns:
        tuple: ``(args, args_text)`` where ``args`` is the namespace produced
        by the main parser and ``args_text`` is a YAML dump of that namespace.
    """
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
    if args_config.config:
        with open(args_config.config, 'r') as f:
            cfg = yaml.safe_load(f)
        # An empty (or comment-only) YAML file loads as None; treat that as
        # "no overrides" instead of failing on ``**None``.
        if cfg is not None:
            parser.set_defaults(**cfg)

    # The main arg parser parses the rest of the args, the usual
    # defaults will have been overridden if config file specified.
    args = parser.parse_args(remaining)

    # Cache the args as a text string to save them in the output dir later
    args_text = yaml.safe_dump(args.__dict__, default_flow_style=False)
    return args, args_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
|
|
|
|
|
setup_default_logging()
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
args, args_text = _parse_args()
|
|
|
|
|
|
|
|
|
|
args.prefetcher = not args.no_prefetcher
|
|
|
|
|
args.distributed = False
|
|
|
|
|
if 'WORLD_SIZE' in os.environ:
|
|
|
|
@ -345,6 +372,8 @@ def main():
|
|
|
|
|
output_dir = get_outdir(output_base, 'train', exp_name)
|
|
|
|
|
decreasing = True if eval_metric == 'loss' else False
|
|
|
|
|
saver = CheckpointSaver(checkpoint_dir=output_dir, decreasing=decreasing)
|
|
|
|
|
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
|
|
|
|
|
f.write(args_text)
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
for epoch in range(start_epoch, num_epochs):
|
|
|
|
|