diff --git a/inference.py b/inference.py index 5fcf1e60..64a6787d 100755 --- a/inference.py +++ b/inference.py @@ -9,12 +9,44 @@ import os import time import argparse import logging +from contextlib import suppress +from functools import partial + import numpy as np +import pandas as pd import torch -from timm.models import create_model, apply_test_time_pool -from timm.data import ImageDataset, create_loader, resolve_data_config -from timm.utils import AverageMeter, setup_default_logging +from timm.models import create_model, apply_test_time_pool, load_checkpoint +from timm.data import create_dataset, create_loader, resolve_data_config +from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser + + + +try: + from apex import amp + has_apex = True +except ImportError: + has_apex = False + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +try: + from functorch.compile import memory_efficient_fusion + has_functorch = True +except ImportError as e: + has_functorch = False + +try: + import torch._dynamo + has_dynamo = True +except ImportError: + has_dynamo = False + torch.backends.cudnn.benchmark = True _logger = logging.getLogger('inference') @@ -23,8 +55,10 @@ _logger = logging.getLogger('inference') parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser.add_argument('data', metavar='DIR', help='path to dataset') -parser.add_argument('--output_dir', metavar='DIR', default='./', - help='path to output files') +parser.add_argument('--dataset', '-d', metavar='NAME', default='', + help='dataset type (default: ImageFolder/ImageTar if empty)') +parser.add_argument('--split', metavar='NAME', default='validation', + help='dataset split (default: validation)') parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', help='model architecture (default: dpn92)') parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', @@ -32,17 +66,25 @@ parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, - metavar='N', help='Input image dimension') + metavar='N', help='Input image dimension, uses model default if empty') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') +parser.add_argument('--use-train-size', action='store_true', default=False, + help='force use of train input size, even when test size is specified in pretrained cfg') +parser.add_argument('--crop-pct', default=None, type=float, + metavar='N', help='Input image center crop pct') +parser.add_argument('--crop-mode', default=None, type=str, + metavar='N', help='Input image crop mode (squash, border, center). Model default if None.') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') -parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of of dataset') parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') -parser.add_argument('--num-classes', type=int, default=1000, +parser.add_argument('--num-classes', type=int, default=None, help='Number classes in dataset') +parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', + help='path to class to idx mapping file (default: "")') parser.add_argument('--log-freq', default=10, type=int, metavar='N', help='batch logging frequency (default: 10)') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -51,10 +93,51 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') -parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', - help='disable test time pool') -parser.add_argument('--topk', default=5, type=int, +parser.add_argument('--test-pool', dest='test_pool', action='store_true', + help='enable test time pool') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--device', default='cuda', type=str, + help="Device (accelerator) to use.") +parser.add_argument('--amp', action='store_true', default=False, + help='use Native AMP for mixed precision training') +parser.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16)') +parser.add_argument('--use-ema', dest='use_ema', action='store_true', + help='use ema version of weights if present') +parser.add_argument('--fuser', default='', type=str, + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--dynamo-backend', default=None, type=str, + help="Select dynamo backend. Default: None") + +scripting_group = parser.add_mutually_exclusive_group() +scripting_group.add_argument('--torchscript', default=False, action='store_true', + help='torch.jit.script the full model') +scripting_group.add_argument('--aot-autograd', default=False, action='store_true', + help="Enable AOT Autograd support.") +scripting_group.add_argument('--dynamo', default=False, action='store_true', + help="Enable Dynamo optimization.") + +parser.add_argument('--results-dir',type=str, default=None, + help='folder for output results') +parser.add_argument('--results-file', type=str, default=None, + help='results filename (relative to results-dir)') +parser.add_argument('--results-format', type=str, default='csv', + help='results format (one of "csv", "json", "json-split", "parquet")') +parser.add_argument('--topk', default=1, type=int, metavar='N', help='Top-k to output to CSV') +parser.add_argument('--fullname', action='store_true', default=False, + help='use full sample name in output (not just basename).') +parser.add_argument('--indices-name', default='index', + help='name for output indices column(s)') +parser.add_argument('--outputs-name', default=None, + help='name for logit/probs output column(s)') +parser.add_argument('--outputs-type', default='prob', + help='output type colum ("prob" for probabilities, "logit" for raw logits)') +parser.add_argument('--separate-columns', action='store_true', default=False, + help='separate output columns per result index.') +parser.add_argument('--exclude-outputs', action='store_true', default=False, + help='exclude logits/probs from results, just indices. topk must be set !=0.') def main(): @@ -63,48 +146,109 @@ def main(): # might as well try to do something useful... args.pretrained = args.pretrained or not args.checkpoint + if torch.cuda.is_available(): + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.benchmark = True + + device = torch.device(args.device) + + # resolve AMP arguments based on PyTorch / Apex availability + use_amp = None + amp_autocast = suppress + if args.amp: + assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).' + assert args.amp_dtype in ('float16', 'bfloat16') + amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16 + amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype) + _logger.info('Running inference in mixed precision with native PyTorch AMP.') + else: + _logger.info('Running inference in float32. AMP not enabled.') + + if args.fuser: + set_jit_fuser(args.fuser) + # create model model = create_model( args.model, num_classes=args.num_classes, in_chans=3, pretrained=args.pretrained, - checkpoint_path=args.checkpoint) + checkpoint_path=args.checkpoint, + ) + if args.num_classes is None: + assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' + args.num_classes = model.num_classes + + if args.checkpoint: + load_checkpoint(model, args.checkpoint, args.use_ema) + + _logger.info( + f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}') - _logger.info('Model %s created, param count: %d' % - (args.model, sum([m.numel() for m in model.parameters()]))) + data_config = resolve_data_config(vars(args), model=model) + test_time_pool = False + if args.test_pool: + model, test_time_pool = apply_test_time_pool(model, data_config) - config = resolve_data_config(vars(args), model=model) - model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config) + model = model.to(device) + model.eval() + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.torchscript: + model = torch.jit.script(model) + elif args.aot_autograd: + assert has_functorch, "functorch is needed for --aot-autograd" + model = memory_efficient_fusion(model) + elif args.dynamo: + assert has_dynamo, "torch._dynamo is needed for --dynamo" + torch._dynamo.reset() + if args.dynamo_backend is not None: + model = torch._dynamo.optimize(args.dynamo_backend)(model) + else: + model = torch._dynamo.optimize()(model) if args.num_gpu > 1: - model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() - else: - model = model.cuda() + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))) + + dataset = create_dataset( + root=args.data, + name=args.dataset, + split=args.split, + class_map=args.class_map, + ) + + if test_time_pool: + data_config['crop_pct'] = 1.0 loader = create_loader( - ImageDataset(args.data), - input_size=config['input_size'], + dataset, batch_size=args.batch_size, use_prefetcher=True, - interpolation=config['interpolation'], - mean=config['mean'], - std=config['std'], num_workers=args.workers, - crop_pct=1.0 if test_time_pool else config['crop_pct']) + **data_config, + ) - model.eval() - - k = min(args.topk, args.num_classes) + top_k = min(args.topk, args.num_classes) batch_time = AverageMeter() end = time.time() - topk_ids = [] + all_indices = [] + all_outputs = [] + use_probs = args.outputs_type == 'prob' with torch.no_grad(): for batch_idx, (input, _) in enumerate(loader): - input = input.cuda() - labels = model(input) - topk = labels.topk(k)[1] - topk_ids.append(topk.cpu().numpy()) + + with amp_autocast(): + output = model(input) + + if use_probs: + output = output.softmax(-1) + + if top_k: + output, indices = output.topk(top_k) + all_indices.append(indices.cpu().numpy()) + + all_outputs.append(output.cpu().numpy()) # measure elapsed time batch_time.update(time.time() - end) @@ -114,13 +258,57 @@ def main(): _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( batch_idx, len(loader), batch_time=batch_time)) - topk_ids = np.concatenate(topk_ids, axis=0) + all_indices = np.concatenate(all_indices, axis=0) if all_indices else None + all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32) + filenames = loader.dataset.filenames(basename=not args.fullname) + + outputs_name = args.outputs_name or ('prob' if use_probs else 'logit') + data_dict = {'filename': filenames} + if args.separate_columns and all_outputs.shape[-1] > 1: + if all_indices is not None: + for i in range(all_indices.shape[-1]): + data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i] + for i in range(all_outputs.shape[-1]): + data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i] + else: + if all_indices is not None: + if all_indices.shape[-1] == 1: + all_indices = all_indices.squeeze(-1) + data_dict[args.indices_name] = list(all_indices) + if all_outputs.shape[-1] == 1: + all_outputs = all_outputs.squeeze(-1) + data_dict[outputs_name] = list(all_outputs) + + df = pd.DataFrame(data=data_dict) + + results_filename = args.results_file + needs_ext = False + if not results_filename: + # base default filename on model name + img-size + img_size = data_config["input_size"][1] + results_filename = f'{args.model}-{img_size}' + needs_ext = True - with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: - filenames = loader.dataset.filenames(basename=True) - for filename, label in zip(filenames, topk_ids): - out_file.write('{0},{1}\n'.format( - filename, ','.join([ str(v) for v in label]))) + if args.results_dir: + results_filename = os.path.join(args.results_dir, results_filename) + + if args.results_format == 'parquet': + if needs_ext: + results_filename += '.parquet' + df = df.set_index('filename') + df.to_parquet(results_filename) + elif args.results_format == 'json': + if needs_ext: + results_filename += '.json' + df.to_json(results_filename, lines=True, orient='records') + elif args.results_format == 'json-split': + if needs_ext: + results_filename += '.json' + df.to_json(results_filename, indent=4, orient='split', index=False) + else: + if needs_ext: + results_filename += '.csv' + df.to_csv(results_filename, index=False) if __name__ == '__main__':