parent 7afadae610
commit acc0d37d96
@@ -0,0 +1,322 @@
#!/usr/bin/env python3
""" ImageNet Validation Script

This is intended to be a lean and easily modifiable ImageNet validation script for evaluating pretrained
models or training checkpoints against ImageNet or similarly organized image datasets. It prioritizes
canonical PyTorch, standard Python style, and good performance. Repurpose as you see fit.

Hacked together by Ross Wightman (https://github.com/rwightman)
"""
import argparse
import os
import csv
import glob
import time
import logging
import torch
import torch.nn as nn
import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
import torchextractor as tx
import torchvision.transforms as T

from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
from timm.data.transforms import _pil_interp

from PIL import Image
import json
import numpy as np
import cv2

has_apex = False
try:
    from apex import amp
    has_apex = True
except ImportError:
    pass

has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass

torch.backends.cudnn.benchmark = True
_logger = logging.getLogger('validate')

parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--dataset', '-d', metavar='NAME', default='',
                    help='dataset type (default: ImageFolder/ImageTar if empty)')
parser.add_argument('--split', metavar='NAME', default='validation',
                    help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                    help='model architecture (default: dpn92)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
                    metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
                    metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--crop-pct', default=None, type=float,
                    metavar='N', help='Input image center crop pct')
parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                    help='Override mean pixel value of dataset')
parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
                    help='Override std deviation of dataset')
parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                    help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=None,
                    help='Number of classes in dataset')
parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                    help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                    help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
                    metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--num-gpu', type=int, default=1,
                    help='Number of GPUs to use')
parser.add_argument('--no-test-pool', dest='no_test_pool', action='store_true',
                    help='disable test time pool')
parser.add_argument('--no-prefetcher', action='store_true', default=False,
                    help='disable fast prefetcher')
parser.add_argument('--pin-mem', action='store_true', default=False,
                    help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--channels-last', action='store_true', default=False,
                    help='Use channels_last memory layout')
parser.add_argument('--amp', action='store_true', default=False,
                    help='Use AMP mixed precision. Defaults to Apex, fallback to native Torch AMP.')
parser.add_argument('--apex-amp', action='store_true', default=False,
                    help='Use NVIDIA Apex AMP mixed precision')
parser.add_argument('--native-amp', action='store_true', default=False,
                    help='Use Native Torch AMP mixed precision')
parser.add_argument('--tf-preprocessing', action='store_true', default=False,
                    help='Use Tensorflow preprocessing pipeline (requires CPU TF installed)')
parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                    help='use ema version of weights if present')
parser.add_argument('--torchscript', dest='torchscript', action='store_true',
                    help='convert model to torchscript for inference')
parser.add_argument('--legacy-jit', dest='legacy_jit', action='store_true',
                    help='use legacy jit mode for pytorch 1.5/1.5.1/1.6 to get back fusion performance')
parser.add_argument('--results-file', default='', type=str, metavar='FILENAME',
                    help='Output csv file for validation results (summary)')
parser.add_argument('--real-labels', default='', type=str, metavar='FILENAME',
                    help='Real labels JSON file for imagenet evaluation')
parser.add_argument('--valid-labels', default='', type=str, metavar='FILENAME',
                    help='Valid label indices txt file for validation of partial label space')
parser.add_argument('--hook', dest='hook', action='store_true',
                    help='hook activations')
parser.add_argument('--prune', dest='prune', type=float, default=0.0,
                    help='prune linear layers')
parser.add_argument('--file', default='', type=str, help='Image file')


def validate(args):
    # might as well try to validate something
    args.pretrained = args.pretrained or not args.checkpoint
    args.prefetcher = not args.no_prefetcher
    amp_autocast = suppress  # do nothing
    if args.amp:
        if has_native_amp:
            args.native_amp = True
        elif has_apex:
            args.apex_amp = True
        else:
            _logger.warning("Neither APEX nor Native Torch AMP is available.")
    assert not args.apex_amp or not args.native_amp, "Only one AMP mode should be set."
    if args.native_amp:
        amp_autocast = torch.cuda.amp.autocast
        _logger.info('Validating in mixed precision with native PyTorch AMP.')
    elif args.apex_amp:
        _logger.info('Validating in mixed precision with NVIDIA APEX AMP.')
    else:
        _logger.info('Validating in float32. AMP not enabled.')

    if args.legacy_jit:
        set_jit_legacy()
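    # Note: amp_autocast is meant to wrap the forward pass as a context manager
    # (suppress is a no-op stand-in when AMP is off). A minimal sketch of the
    # intended pattern -- the full batched evaluation loop is not part of this file:
    #
    #   with torch.no_grad(), amp_autocast():
    #       output = model(input)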
    # create model
    model = create_model(
        args.model,
        pretrained=args.pretrained,
        num_classes=args.num_classes,
        in_chans=3,
        global_pool=args.gp,
        scriptable=args.torchscript)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes

    if args.checkpoint:
        load_checkpoint(model, args.checkpoint, args.use_ema)

    param_count = sum([m.numel() for m in model.parameters()])
    _logger.info('Model %s created, param count: %d' % (args.model, param_count))

    data_config = resolve_data_config(vars(args), model=model, use_test_size=True, verbose=True)
    test_time_pool = False
    if not args.no_test_pool:
        model, test_time_pool = apply_test_time_pool(model, data_config, use_test_size=True)

    # standard PyTorch mean-std input image normalization
    transform = T.Compose([
        T.Resize((data_config["input_size"][1], data_config["input_size"][2]), _pil_interp("bicubic")),
        T.CenterCrop(data_config["input_size"][1]),
        T.ToTensor(),
        T.Normalize(mean=torch.tensor(data_config["mean"]), std=torch.tensor(data_config["std"]))
    ])

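    # Note: since Resize already produces the exact (h, w) from data_config, the
    # CenterCrop that follows is a no-op for square inputs; it only has an effect
    # when the two dimensions differ.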
    if args.hook:
        from torchsummary import summary

        layer_names = []
        sparse_layers = 0
        layer_nr = 0

        summary(model.cuda(), (3, 224, 224))

        for name, module in model.named_modules():
            layer_names.append(name)
            #if (isinstance(module, torch.nn.Linear)):
            #    print('Linear ', layer_nr, ' : ', name, ' shape: ', module.weight.shape)

            if hasattr(module, 'weight') and isinstance(module, torch.nn.Linear):
                weights = module.weight.detach()
                # percentage of exactly-zero and of small-magnitude (< 0.05) weights
                zeros = weights.numel() - weights.nonzero().size(0)
                sparsity = zeros / weights.numel() * 100.0
                average = torch.mean(abs(weights))
                small_val = torch.sum((abs(weights) < 0.05).int()).item() / weights.numel() * 100.0
                #small_pos = torch.sum((weights < 0.05).int()).item() / weights.numel() * 100.0
                #small_neg = torch.sum((weights < -0.05).int()).item() / weights.numel() * 100.0
                if small_val > 70:
                    #print("layer: ", name, module, ", sparsity: ", sparsity, " small=", int(small_val), ", < 0.05: ", int(small_pos), " neg: ", int(small_neg))
                    print("layer: ", name, module, " small=", int(small_val), ' sparse=', sparsity)
                    sparse_layers += 1
                else:
                    print("layer: ", name, module, " mean=", average)

            # if (name == "model.backbone.conv_stem"):
            #     print(module.weight.shape, module.weight.detach().numpy())
            layer_nr += 1

        print(layer_names)
        exit()

        #model = tx.Extractor(model, layer_names)

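    # Sketch of the activation-hook idea behind the commented-out tx.Extractor line
    # above, using torchextractor (the layer names here are placeholders, not taken
    # from this script):
    #
    #   model = tx.Extractor(model, ["blocks.0", "blocks.1"])
    #   output, features = model(img)   # features: dict mapping layer name -> tensor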
    if args.prune != 0.0:
        # prune all linear layer weights with magnitude below args.prune
        with torch.no_grad():
            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear):
                    module.weight[module.weight.abs() < args.prune] = 0.0

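    # An equivalent, reusable approach (a sketch, not part of the original script)
    # uses torch.nn.utils.prune, which registers a mask rather than zeroing in place:
    #
    #   import torch.nn.utils.prune as prune_utils
    #   for _, module in model.named_modules():
    #       if isinstance(module, torch.nn.Linear):
    #           prune_utils.custom_from_mask(module, 'weight', module.weight.abs() >= args.prune)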
    if args.torchscript:
        torch.jit.optimized_execution(True)
        model = torch.jit.script(model)

    model = model.cuda()
    if args.apex_amp:
        model = amp.initialize(model, opt_level='O1')

    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    if args.num_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu)))

    with open(os.path.join(args.data, 'imagenet_class_index.json')) as f:
        imagenet_dict = json.load(f)

    model.eval()

    if args.file:
        im = Image.open(args.file)

        print('Image ', args.file, ' size=', im.size)
        img = transform(im).unsqueeze(0)
        img = img.cuda()

        with torch.no_grad():
            x_class = model(img)
        max_idx = np.argmax(x_class.cpu().detach().numpy())
        print(x_class[0][max_idx], max_idx)
        #for i in range(1000):
        #    if (x_class[0][i] > 0.5):
        #        print('Index: ', i, ' ', x_class[0][i], ' ', imagenet_dict[str(i)][1])
        #class_name = imagenet_dict[]

        frame = cv2.cvtColor(np.asarray(im, dtype=np.uint8), cv2.COLOR_RGB2BGR)
        cv2.imshow('CLASS: ' + str(max_idx) + ' ' + imagenet_dict[str(max_idx)][1], frame)

        ch = cv2.waitKey()

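# Example single-image invocation (script name and paths are hypothetical):
#   python validate.py /data/imagenet --model dpn92 --pretrained --file /data/images/dog.JPEG
# The positional DIR argument is still required in this mode because the script
# loads imagenet_class_index.json from it.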
def main():
    setup_default_logging()
    args = parser.parse_args()
    model_cfgs = []
    model_names = []
    if os.path.isdir(args.checkpoint):
        # validate all checkpoints in a path with same model
        checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
        checkpoints += glob.glob(args.checkpoint + '/*.pth')
        model_names = list_models(args.model)
        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
    else:
        if args.model == 'all':
            # validate all models in a list of names with pretrained checkpoints
            args.pretrained = True
            model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k'])
            model_cfgs = [(n, '') for n in model_names]
        elif not is_model(args.model):
            # model name doesn't exist, try as wildcard filter
            model_names = list_models(args.model)
            model_cfgs = [(n, '') for n in model_names]

    if len(model_cfgs):
        _logger.info('Running bulk validation on these pretrained models: {}'.format(', '.join(model_names)))
        try:
            start_batch_size = args.batch_size
            for m, c in model_cfgs:
                batch_size = start_batch_size
                args.model = m
                args.checkpoint = c
                r = {}  # validate() must return a truthy result for the retry loop to end
                while not r and batch_size >= args.num_gpu:
                    torch.cuda.empty_cache()
                    try:
                        args.batch_size = batch_size
                        print('Validating with batch size: %d' % args.batch_size)
                        r = validate(args)
                    except RuntimeError:
                        if batch_size <= args.num_gpu:
                            print("Validation failed with no ability to reduce batch size. Exiting.")
                            raise
                        batch_size = max(batch_size // 2, args.num_gpu)
                        print("Validation failed, reducing batch size by 50%")

        except KeyboardInterrupt:
            pass
    else:
        validate(args)

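# Example bulk run (hypothetical): a name that is not an exact model is expanded
# as a wildcard via list_models(), validating every match with pretrained weights:
#   python validate.py /data/imagenet --model 'efficientnet_b*' --pretrained --amp -b 256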


if __name__ == '__main__':
    main()