#!/usr/bin/env python
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import csv
import glob
import logging
import os
import time
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.parallel

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import Dataset, DatasetTar, create_loader, resolve_data_config
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging

torch.backends.cudnn.benchmark = True

parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                    help='model architecture (default: dpn92)')
# The help text previously claimed "default: 2" while the real default is 4.
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N',
                    help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int, metavar='N',
                    help='Input image dimension, uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float, metavar='N',
                    help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
                    help='Number classes in dataset')
parser.add_argument('--log-freq', default=10, type=int, metavar='N',
                    help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUS to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                    help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--fp16', action='store_true', default=False,
                    help='Use half precision (fp16)')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (require CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model torchscript for inference')


def validate(args: argparse.Namespace) -> OrderedDict:
    """Run one validation pass for the model/checkpoint described by ``args``.

    Builds the model, optionally loads a checkpoint, moves everything to CUDA
    (``.cuda()`` is called unconditionally, so a CUDA device is required),
    iterates the dataset once under ``torch.no_grad()`` and logs running
    loss / top-1 / top-5 figures every ``args.log_freq`` batches.

    Side effects on ``args``: ``args.pretrained`` is forced to True when no
    checkpoint is given, and ``args.prefetcher`` is set to the negation of
    ``args.no_prefetcher``.

    Args:
        args: Parsed command-line namespace produced by the module-level
            ``parser``.

    Returns:
        An ``OrderedDict`` with the keys ``top1``, ``top1_err``, ``top5``,
        ``top5_err``, ``param_count`` (in millions), ``img_size``,
        ``cropt_pct`` and ``interpolation``.
    """
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher

    # create model
    model = create_model(
        args.model,
        num_classes=args.num_classes,
        in_chans=3,
        pretrained=args.pretrained)

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum(p.numel() for p in model.parameters())
    logging.info('Model %s created, param count: %d', args.model, param_count)

    data_config = resolve_data_config(vars(args), model=model)
    model, test_time_pool = apply_test_time_pool(model, data_config, args)

    if args.torchscript:
        # NOTE(review): torch.jit.optimized_execution appears to be a context
        # manager; called bare like this it is probably a no-op. Kept as-is to
        # preserve behaviour -- confirm against the installed torch version.
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda()
    else:
        model = model.cuda()

    if args.fp16:
        model = model.half()

    criterion = nn.CrossEntropyLoss().cuda()

    # A path ending in .tar that is an existing file is read as a tar dataset;
    # anything else is handed to the regular Dataset class.
    if os.path.splitext(args.data)[1] == '.tar' and os.path.isfile(args.data):
        dataset = DatasetTar(args.data, load_bytes=args.tf_preprocessing)
    else:
        dataset = Dataset(args.data, load_bytes=args.tf_preprocessing)

    # Test-time pooling evaluates the full image, so no center crop is applied.
    crop_pct = 1.0 if test_time_pool else data_config['crop_pct']
    loader = create_loader(
        dataset,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        crop_pct=crop_pct,
        pin_memory=args.pin_mem,
        fp16=args.fp16,
        tf_preprocessing=args.tf_preprocessing)

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.eval()
    end = time.time()
    with torch.no_grad():
        # `images` (rather than `input`) avoids shadowing the builtin.
        for i, (images, target) in enumerate(loader):
            if args.no_prefetcher:
                # Without the prefetcher the batch is moved (and, for fp16,
                # cast) here instead.
                target = target.cuda()
                images = images.cuda()
                if args.fp16:
                    images = images.half()

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            batch_size = images.size(0)
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), batch_size)
            top1.update(prec1.item(), batch_size)
            top5.update(prec5.item(), batch_size)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.log_freq == 0:
                logging.info(
                    'Test: [{0:>4d}/{1}] '
                    'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
                    'Loss: {loss.val:>7.4f} ({loss.avg:>6.4f}) '
                    'Prec@1: {top1.val:>7.3f} ({top1.avg:>7.3f}) '
                    'Prec@5: {top5.val:>7.3f} ({top5.avg:>7.3f})'.format(
                        i, len(loader), batch_time=batch_time,
                        rate_avg=batch_size / batch_time.avg,
                        loss=losses, top1=top1, top5=top5))

    results = OrderedDict(
        top1=round(top1.avg, 4), top1_err=round(100 - top1.avg, 4),
        top5=round(top5.avg, 4), top5_err=round(100 - top5.avg, 4),
        param_count=round(param_count / 1e6, 2),
        img_size=data_config['input_size'][-1],
        # NOTE: 'cropt_pct' (sic) is kept because it is the returned key and
        # the CSV column name; renaming it would break existing consumers.
        cropt_pct=crop_pct,
        interpolation=data_config['interpolation'])

    logging.info(' * Prec@1 {:.3f} ({:.3f}) Prec@5 {:.3f} ({:.3f})'.format(
        results['top1'], results['top1_err'], results['top5'], results['top5_err']))

    return results


def main() -> None:
    """Command-line entry point.

    Three modes, chosen from the parsed arguments:

    * ``--checkpoint`` is a directory: validate every ``*.pth.tar`` and
      ``*.pth`` file in it (naturally sorted) with the single ``--model``.
    * ``--model all``: validate every model that ``list_models`` reports as
      having pretrained weights.
    * ``--model`` is not a known model name: treat it as a wildcard filter
      for ``list_models`` and validate each match with pretrained weights.

    In any of those bulk modes one row per run is written to
    ``./results-all.csv``. Otherwise a single ``validate(args)`` call is made.
    """
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(os.path.join(args.checkpoint, '*.pth.tar'))
        checkpoints += glob.glob(os.path.join(args.checkpoint, '*.pth'))
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    elif args.model == 'all':
        # validate all models in a list of names with pretrained checkpoints
        args.pretrained = True
        model_names = list_models(pretrained=True)
        model_cfgs = [(n, '') for n in model_names]
    elif not is_model(args.model):
        # model name doesn't exist, try as wildcard filter
        model_names = list_models(args.model)
        model_cfgs = [(n, '') for n in model_names]

    if not model_cfgs:
        validate(args)
        return

    logging.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
    # newline='' is what the csv module expects of the file it writes to;
    # without it, platforms that translate line endings emit blank rows.
    with open('./results-all.csv', mode='w', newline='') as cf:
        writer = None
        for model_name, checkpoint in model_cfgs:
            args.model = model_name
            args.checkpoint = checkpoint
            result = OrderedDict(model=args.model)
            result.update(validate(args))
            if args.checkpoint:
                result['checkpoint'] = args.checkpoint
            if writer is None:
                # Column order comes from the first result; every row in one
                # run carries the same keys.
                writer = csv.DictWriter(cf, fieldnames=list(result.keys()))
                writer.writeheader()
            writer.writerow(result)
            # Flush per row so partial results survive an interrupted run.
            cf.flush()


if __name__ == '__main__':
    main()