diff --git a/validate.py b/validate.py index ab5b644f..9b2c0f7e 100755 --- a/validate.py +++ b/validate.py @@ -296,6 +296,11 @@ def main(): model_names = list_models(args.model) model_cfgs = [(n, '') for n in model_names] + if not model_cfgs and os.path.isfile(args.model): + with open(args.model) as f: + model_names = [line.rstrip() for line in f] + model_cfgs = [(n, None) for n in model_names if n] + if len(model_cfgs): results_file = args.results_file or './results-all.csv' _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))