From 7e2e69d608df843f14313993f79186e005052d8b Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 25 Nov 2022 23:03:54 -0800
Subject: [PATCH] More inference script changes, arg naming, multiple output
 fmts at once

---
 inference.py | 80 ++++++++++++++++++++++++++++------------------------
 1 file changed, 43 insertions(+), 37 deletions(-)

diff --git a/inference.py b/inference.py
index 64a6787d..5a6b77e9 100755
--- a/inference.py
+++ b/inference.py
@@ -48,6 +48,13 @@ except ImportError:
     has_dynamo = False
 
 
+_FMT_EXT = {
+    'json': '.json',
+    'json-split': '.json',
+    'parquet': '.parquet',
+    'csv': '.csv',
+}
+
 torch.backends.cudnn.benchmark = True
 _logger = logging.getLogger('inference')
 
@@ -103,8 +110,6 @@ parser.add_argument('--amp', action='store_true', default=False,
                     help='use Native AMP for mixed precision training')
 parser.add_argument('--amp-dtype', default='float16', type=str,
                     help='lower precision AMP dtype (default: float16)')
-parser.add_argument('--use-ema', dest='use_ema', action='store_true',
-                    help='use ema version of weights if present')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
 parser.add_argument('--dynamo-backend', default=None, type=str,
@@ -122,21 +127,23 @@ parser.add_argument('--results-dir', type=str, default=None,
                     help='folder for output results')
 parser.add_argument('--results-file', type=str, default=None,
                     help='results filename (relative to results-dir)')
-parser.add_argument('--results-format', type=str, default='csv',
+parser.add_argument('--results-format', type=str, nargs='+', default=['csv'],
                     help='results format (one of "csv", "json", "json-split", "parquet")')
+parser.add_argument('--results-separate-col', action='store_true', default=False,
+                    help='separate output columns per result index.')
 parser.add_argument('--topk', default=1, type=int,
                     metavar='N', help='Top-k to output to CSV')
 parser.add_argument('--fullname', action='store_true', default=False,
                     help='use full sample name in output (not just basename).')
-parser.add_argument('--indices-name', default='index',
+parser.add_argument('--filename-col', default='filename',
+                    help='name for filename / sample name column')
+parser.add_argument('--index-col', default='index',
                     help='name for output indices column(s)')
-parser.add_argument('--outputs-name', default=None,
+parser.add_argument('--output-col', default=None,
                     help='name for logit/probs output column(s)')
-parser.add_argument('--outputs-type', default='prob',
+parser.add_argument('--output-type', default='prob',
                     help='output type colum ("prob" for probabilities, "logit" for raw logits)')
-parser.add_argument('--separate-columns', action='store_true', default=False,
-                    help='separate output columns per result index.')
-parser.add_argument('--exclude-outputs', action='store_true', default=False,
+parser.add_argument('--exclude-output', action='store_true', default=False,
                     help='exclude logits/probs from results, just indices. topk must be set !=0.')
 
 
@@ -179,9 +186,6 @@ def main():
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes
 
-    if args.checkpoint:
-        load_checkpoint(model, args.checkpoint, args.use_ema)
-
     _logger.info(
         f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
 
@@ -221,11 +225,12 @@ def main():
     if test_time_pool:
         data_config['crop_pct'] = 1.0
 
+    workers = 1 if 'tfds' in args.dataset or 'wds' in args.dataset else args.workers
     loader = create_loader(
         dataset,
         batch_size=args.batch_size,
         use_prefetcher=True,
-        num_workers=args.workers,
+        num_workers=workers,
         **data_config,
     )
 
@@ -234,7 +239,7 @@ def main():
     end = time.time()
     all_indices = []
     all_outputs = []
-    use_probs = args.outputs_type == 'prob'
+    use_probs = args.output_type == 'prob'
     with torch.no_grad():
         for batch_idx, (input, _) in enumerate(loader):
             with amp_autocast():
@@ -262,52 +267,53 @@ def main():
     all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
     filenames = loader.dataset.filenames(basename=not args.fullname)
 
-    outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
-    data_dict = {'filename': filenames}
-    if args.separate_columns and all_outputs.shape[-1] > 1:
+    output_col = args.output_col or ('prob' if use_probs else 'logit')
+    data_dict = {args.filename_col: filenames}
+    if args.results_separate_col and all_outputs.shape[-1] > 1:
         if all_indices is not None:
             for i in range(all_indices.shape[-1]):
-                data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
+                data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
         for i in range(all_outputs.shape[-1]):
-            data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
+            data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
     else:
         if all_indices is not None:
             if all_indices.shape[-1] == 1:
                 all_indices = all_indices.squeeze(-1)
-            data_dict[args.indices_name] = list(all_indices)
+            data_dict[args.index_col] = list(all_indices)
         if all_outputs.shape[-1] == 1:
             all_outputs = all_outputs.squeeze(-1)
-        data_dict[outputs_name] = list(all_outputs)
+        data_dict[output_col] = list(all_outputs)
 
     df = pd.DataFrame(data=data_dict)
 
     results_filename = args.results_file
-    needs_ext = False
-    if not results_filename:
+    if results_filename:
+        filename_no_ext, ext = os.path.splitext(results_filename)
+        if ext and ext in _FMT_EXT.values():
+            # if filename provided with one of expected ext,
+            # remove it as it will be added back
+            results_filename = filename_no_ext
+    else:
         # base default filename on model name + img-size
         img_size = data_config["input_size"][1]
         results_filename = f'{args.model}-{img_size}'
-        needs_ext = True
 
     if args.results_dir:
         results_filename = os.path.join(args.results_dir, results_filename)
 
-    if args.results_format == 'parquet':
-        if needs_ext:
-            results_filename += '.parquet'
-        df = df.set_index('filename')
-        df.to_parquet(results_filename)
-    elif args.results_format == 'json':
-        if needs_ext:
-            results_filename += '.json'
+    for fmt in args.results_format:
+        save_results(df, results_filename, fmt, args.filename_col)
+
+
+def save_results(df, results_filename, results_format='csv', filename_col='filename'):
+    results_filename += _FMT_EXT[results_format]
+    if results_format == 'parquet':
+        df.set_index(filename_col).to_parquet(results_filename)
+    elif results_format == 'json':
         df.to_json(results_filename, lines=True, orient='records')
-    elif args.results_format == 'json-split':
-        if needs_ext:
-            results_filename += '.json'
+    elif results_format == 'json-split':
        df.to_json(results_filename, indent=4, orient='split', index=False)
     else:
-        if needs_ext:
-            results_filename += '.csv'
         df.to_csv(results_filename, index=False)
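
Usage sketch (illustrative only; the dataset path and model choice below are
assumptions, and with resnet50 the default 224x224 input size gives the
filenames shown): since --results-format now takes multiple values, a single
run can write the same results in several formats at once:

    python inference.py /path/to/imagenet/val --model resnet50 \
        --results-format csv parquet json --topk 5

With no --results-file given, this would write resnet50-224.csv,
resnet50-224.parquet, and resnet50-224.json (default name is model name plus
input size, with the extension for each format appended from _FMT_EXT). If
--results-file is supplied with one of the recognized extensions, that
extension is stripped and re-added per requested format.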