Significant upgrade to inference.py, support for different formats, formatting, etc.

pull/1582/head
Ross Wightman 2 years ago committed by Ross Wightman
parent 4d5c395160
commit eceeb9409a

@ -9,12 +9,44 @@ import os
import time import time
import argparse import argparse
import logging import logging
from contextlib import suppress
from functools import partial
import numpy as np import numpy as np
import pandas as pd
import torch import torch
from timm.models import create_model, apply_test_time_pool from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import ImageDataset, create_loader, resolve_data_config from timm.data import create_dataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
try:
from apex import amp
has_apex = True
except ImportError:
has_apex = False
has_native_amp = False
try:
if getattr(torch.cuda.amp, 'autocast') is not None:
has_native_amp = True
except AttributeError:
pass
try:
from functorch.compile import memory_efficient_fusion
has_functorch = True
except ImportError as e:
has_functorch = False
try:
import torch._dynamo
has_dynamo = True
except ImportError:
has_dynamo = False
torch.backends.cudnn.benchmark = True torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('inference') _logger = logging.getLogger('inference')
@ -23,8 +55,10 @@ _logger = logging.getLogger('inference')
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference') parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
parser.add_argument('data', metavar='DIR', parser.add_argument('data', metavar='DIR',
help='path to dataset') help='path to dataset')
parser.add_argument('--output_dir', metavar='DIR', default='./', parser.add_argument('--dataset', '-d', metavar='NAME', default='',
help='path to output files') help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92', parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
help='model architecture (default: dpn92)') help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
@ -32,17 +66,25 @@ parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
parser.add_argument('-b', '--batch-size', default=256, type=int, parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)') metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int, parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension') metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int, parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
help='force use of train input size, even when test size is specified in pretrained cfg')
parser.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop pct')
parser.add_argument('--crop-mode', default=None, type=str,
metavar='N', help='Input image crop mode (squash, border, center). Model default if None.')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset') help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of of dataset') help='Override std deviation of of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME', parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)') help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000, parser.add_argument('--num-classes', type=int, default=None,
help='Number classes in dataset') help='Number classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--log-freq', default=10, type=int, parser.add_argument('--log-freq', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)') metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -51,10 +93,51 @@ parser.add_argument('--pretrained', dest='pretrained', action='store_true',
help='use pre-trained model') help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1, parser.add_argument('--num-gpu', type=int, default=1,
help='Number of GPUS to use') help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true', parser.add_argument('--test-pool', dest='test_pool', action='store_true',
help='disable test time pool') help='enable test time pool')
parser.add_argument('--topk', default=5, type=int, parser.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
parser.add_argument('--device', default='cuda', type=str,
help="Device (accelerator) to use.")
parser.add_argument('--amp', action='store_true', default=False,
help='use Native AMP for mixed precision training')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--dynamo-backend', default=None, type=str,
help="Select dynamo backend. Default: None")
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
help='torch.jit.script the full model')
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd support.")
scripting_group.add_argument('--dynamo', default=False, action='store_true',
help="Enable Dynamo optimization.")
parser.add_argument('--results-dir',type=str, default=None,
help='folder for output results')
parser.add_argument('--results-file', type=str, default=None,
help='results filename (relative to results-dir)')
parser.add_argument('--results-format', type=str, default='csv',
help='results format (one of "csv", "json", "json-split", "parquet")')
parser.add_argument('--topk', default=1, type=int,
metavar='N', help='Top-k to output to CSV') metavar='N', help='Top-k to output to CSV')
parser.add_argument('--fullname', action='store_true', default=False,
help='use full sample name in output (not just basename).')
parser.add_argument('--indices-name', default='index',
help='name for output indices column(s)')
parser.add_argument('--outputs-name', default=None,
help='name for logit/probs output column(s)')
parser.add_argument('--outputs-type', default='prob',
help='output type colum ("prob" for probabilities, "logit" for raw logits)')
parser.add_argument('--separate-columns', action='store_true', default=False,
help='separate output columns per result index.')
parser.add_argument('--exclude-outputs', action='store_true', default=False,
help='exclude logits/probs from results, just indices. topk must be set !=0.')
def main(): def main():
@ -63,48 +146,109 @@ def main():
# might as well try to do something useful... # might as well try to do something useful...
args.pretrained = args.pretrained or not args.checkpoint args.pretrained = args.pretrained or not args.checkpoint
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
device = torch.device(args.device)
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_autocast = suppress
if args.amp:
assert has_native_amp, 'Please update PyTorch to a version with native AMP (or use APEX).'
assert args.amp_dtype in ('float16', 'bfloat16')
amp_dtype = torch.bfloat16 if args.amp_dtype == 'bfloat16' else torch.float16
amp_autocast = partial(torch.autocast, device_type=device.type, dtype=amp_dtype)
_logger.info('Running inference in mixed precision with native PyTorch AMP.')
else:
_logger.info('Running inference in float32. AMP not enabled.')
if args.fuser:
set_jit_fuser(args.fuser)
# create model # create model
model = create_model( model = create_model(
args.model, args.model,
num_classes=args.num_classes, num_classes=args.num_classes,
in_chans=3, in_chans=3,
pretrained=args.pretrained, pretrained=args.pretrained,
checkpoint_path=args.checkpoint) checkpoint_path=args.checkpoint,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
_logger.info('Model %s created, param count: %d' % if args.checkpoint:
(args.model, sum([m.numel() for m in model.parameters()]))) load_checkpoint(model, args.checkpoint, args.use_ema)
config = resolve_data_config(vars(args), model=model) _logger.info(
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config) f'Model {args.model} created, param count: {sum([m.numel() for m in model.parameters()])}')
if args.num_gpu > 1: data_config = resolve_data_config(vars(args), model=model)
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() test_time_pool = False
if args.test_pool:
model, test_time_pool = apply_test_time_pool(model, data_config)
model = model.to(device)
model.eval()
if args.channels_last:
model = model.to(memory_format=torch.channels_last)
if args.torchscript:
model = torch.jit.script(model)
elif args.aot_autograd:
assert has_functorch, "functorch is needed for --aot-autograd"
model = memory_efficient_fusion(model)
elif args.dynamo:
assert has_dynamo, "torch._dynamo is needed for --dynamo"
torch._dynamo.reset()
if args.dynamo_backend is not None:
model = torch._dynamo.optimize(args.dynamo_backend)(model)
else: else:
model = model.cuda() model = torch._dynamo.optimize()(model)
if args.num_gpu > 1:
model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))
dataset = create_dataset(
root=args.data,
name=args.dataset,
split=args.split,
class_map=args.class_map,
)
if test_time_pool:
data_config['crop_pct'] = 1.0
loader = create_loader( loader = create_loader(
ImageDataset(args.data), dataset,
input_size=config['input_size'],
batch_size=args.batch_size, batch_size=args.batch_size,
use_prefetcher=True, use_prefetcher=True,
interpolation=config['interpolation'],
mean=config['mean'],
std=config['std'],
num_workers=args.workers, num_workers=args.workers,
crop_pct=1.0 if test_time_pool else config['crop_pct']) **data_config,
)
model.eval() top_k = min(args.topk, args.num_classes)
k = min(args.topk, args.num_classes)
batch_time = AverageMeter() batch_time = AverageMeter()
end = time.time() end = time.time()
topk_ids = [] all_indices = []
all_outputs = []
use_probs = args.outputs_type == 'prob'
with torch.no_grad(): with torch.no_grad():
for batch_idx, (input, _) in enumerate(loader): for batch_idx, (input, _) in enumerate(loader):
input = input.cuda()
labels = model(input) with amp_autocast():
topk = labels.topk(k)[1] output = model(input)
topk_ids.append(topk.cpu().numpy())
if use_probs:
output = output.softmax(-1)
if top_k:
output, indices = output.topk(top_k)
all_indices.append(indices.cpu().numpy())
all_outputs.append(output.cpu().numpy())
# measure elapsed time # measure elapsed time
batch_time.update(time.time() - end) batch_time.update(time.time() - end)
@ -114,13 +258,57 @@ def main():
_logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format( _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time)) batch_idx, len(loader), batch_time=batch_time))
topk_ids = np.concatenate(topk_ids, axis=0) all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
filenames = loader.dataset.filenames(basename=not args.fullname)
outputs_name = args.outputs_name or ('prob' if use_probs else 'logit')
data_dict = {'filename': filenames}
if args.separate_columns and all_outputs.shape[-1] > 1:
if all_indices is not None:
for i in range(all_indices.shape[-1]):
data_dict[f'{args.indices_name}_{i}'] = all_indices[:, i]
for i in range(all_outputs.shape[-1]):
data_dict[f'{outputs_name}_{i}'] = all_outputs[:, i]
else:
if all_indices is not None:
if all_indices.shape[-1] == 1:
all_indices = all_indices.squeeze(-1)
data_dict[args.indices_name] = list(all_indices)
if all_outputs.shape[-1] == 1:
all_outputs = all_outputs.squeeze(-1)
data_dict[outputs_name] = list(all_outputs)
df = pd.DataFrame(data=data_dict)
results_filename = args.results_file
needs_ext = False
if not results_filename:
# base default filename on model name + img-size
img_size = data_config["input_size"][1]
results_filename = f'{args.model}-{img_size}'
needs_ext = True
with open(os.path.join(args.output_dir, './topk_ids.csv'), 'w') as out_file: if args.results_dir:
filenames = loader.dataset.filenames(basename=True) results_filename = os.path.join(args.results_dir, results_filename)
for filename, label in zip(filenames, topk_ids):
out_file.write('{0},{1}\n'.format( if args.results_format == 'parquet':
filename, ','.join([ str(v) for v in label]))) if needs_ext:
results_filename += '.parquet'
df = df.set_index('filename')
df.to_parquet(results_filename)
elif args.results_format == 'json':
if needs_ext:
results_filename += '.json'
df.to_json(results_filename, lines=True, orient='records')
elif args.results_format == 'json-split':
if needs_ext:
results_filename += '.json'
df.to_json(results_filename, indent=4, orient='split', index=False)
else:
if needs_ext:
results_filename += '.csv'
df.to_csv(results_filename, index=False)
if __name__ == '__main__': if __name__ == '__main__':

Loading…
Cancel
Save