|
|
|
@ -19,6 +19,8 @@ import torch.nn.parallel
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
from contextlib import suppress
|
|
|
|
|
|
|
|
|
|
import torchvision.models as models
|
|
|
|
|
|
|
|
|
|
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
|
|
|
|
|
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
|
|
|
|
|
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
|
|
|
|
@ -50,6 +52,12 @@ parser.add_argument('--split', metavar='NAME', default='validation',
|
|
|
|
|
help='dataset split (default: validation)')
|
|
|
|
|
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
|
|
|
|
|
help='model architecture (default: dpn92)')
|
|
|
|
|
parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
|
|
|
|
|
help='Get a Torchvision model by name')
|
|
|
|
|
parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
|
|
|
|
|
help='Get a model from PyTorch Hub by name')
|
|
|
|
|
parser.add_argument('--hub-model-github-or-dir', type=str,
|
|
|
|
|
help='Specify local directory or Github repository to load model by PyTorch Hub')
|
|
|
|
|
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
|
|
|
|
|
help='number of data loading workers (default: 2)')
|
|
|
|
|
parser.add_argument('-b', '--batch-size', default=256, type=int,
|
|
|
|
@ -72,7 +80,7 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
|
|
|
|
|
help='path to class to idx mapping file (default: "")')
|
|
|
|
|
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
|
|
|
|
|
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
|
|
|
|
|
parser.add_argument('--log-freq', default=10, type=int,
|
|
|
|
|
parser.add_argument('--log-interval', default=10, type=int,
|
|
|
|
|
metavar='N', help='batch logging frequency (default: 10)')
|
|
|
|
|
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
|
|
|
|
|
help='path to latest checkpoint (default: none)')
|
|
|
|
@ -135,13 +143,22 @@ def validate(args):
|
|
|
|
|
set_jit_legacy()
|
|
|
|
|
|
|
|
|
|
# create model
|
|
|
|
|
model = create_model(
|
|
|
|
|
args.model,
|
|
|
|
|
pretrained=args.pretrained,
|
|
|
|
|
num_classes=args.num_classes,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
global_pool=args.gp,
|
|
|
|
|
scriptable=args.torchscript)
|
|
|
|
|
model_name = args.model
|
|
|
|
|
if args.torchvision_model:
|
|
|
|
|
model_name = args.torchvision_model
|
|
|
|
|
model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
|
|
|
|
|
elif args.hub_model and args.hub_model_github_or_dir:
|
|
|
|
|
model_name = args.hub_model
|
|
|
|
|
model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
|
|
|
|
|
else:
|
|
|
|
|
model = create_model(
|
|
|
|
|
args.model,
|
|
|
|
|
pretrained=args.pretrained,
|
|
|
|
|
num_classes=args.num_classes,
|
|
|
|
|
in_chans=3,
|
|
|
|
|
global_pool=args.gp,
|
|
|
|
|
scriptable=args.torchscript)
|
|
|
|
|
|
|
|
|
|
if args.num_classes is None:
|
|
|
|
|
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
|
|
|
|
args.num_classes = model.num_classes
|
|
|
|
@ -150,7 +167,7 @@ def validate(args):
|
|
|
|
|
load_checkpoint(model, args.checkpoint, args.use_ema)
|
|
|
|
|
|
|
|
|
|
param_count = sum([m.numel() for m in model.parameters()])
|
|
|
|
|
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
|
|
|
|
|
_logger.info('Model %s created, param count: %d' % (model_name, param_count))
|
|
|
|
|
|
|
|
|
|
data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
|
|
|
|
|
test_time_pool = False
|
|
|
|
@ -244,7 +261,7 @@ def validate(args):
|
|
|
|
|
batch_time.update(time.time() - end)
|
|
|
|
|
end = time.time()
|
|
|
|
|
|
|
|
|
|
if batch_idx % args.log_freq == 0:
|
|
|
|
|
if batch_idx % args.log_interval == 0:
|
|
|
|
|
_logger.info(
|
|
|
|
|
'Test: [{0:>4d}/{1}] '
|
|
|
|
|
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
|
|
|
|
@ -277,23 +294,30 @@ def validate(args):
|
|
|
|
|
def main():
|
|
|
|
|
setup_default_logging()
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
|
|
model_name = args.model
|
|
|
|
|
if args.torchvision_model:
|
|
|
|
|
model_name = args.torchvision_model
|
|
|
|
|
elif args.hub_model and args.hub_model_github_or_dir:
|
|
|
|
|
model_name = args.hub_model
|
|
|
|
|
|
|
|
|
|
model_cfgs = []
|
|
|
|
|
model_names = []
|
|
|
|
|
if os.path.isdir(args.checkpoint):
|
|
|
|
|
# validate all checkpoints in a path with same model
|
|
|
|
|
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
|
|
|
|
|
checkpoints += glob.glob(args.checkpoint + '/*.pth')
|
|
|
|
|
model_names = list_models(args.model)
|
|
|
|
|
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
|
|
|
|
|
model_names = list_models(model_name)
|
|
|
|
|
model_cfgs = [(model_name, c) for c in sorted(checkpoints, key=natural_key)]
|
|
|
|
|
else:
|
|
|
|
|
if args.model == 'all':
|
|
|
|
|
if model_name == 'all':
|
|
|
|
|
# validate all models in a list of names with pretrained checkpoints
|
|
|
|
|
args.pretrained = True
|
|
|
|
|
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
|
|
|
|
|
model_cfgs = [(n, '') for n in model_names]
|
|
|
|
|
elif not is_model(args.model):
|
|
|
|
|
elif not is_model(model_name):
|
|
|
|
|
# model name doesn't exist, try as wildcard filter
|
|
|
|
|
model_names = list_models(args.model)
|
|
|
|
|
model_names = list_models(model_name)
|
|
|
|
|
model_cfgs = [(n, '') for n in model_names]
|
|
|
|
|
|
|
|
|
|
if len(model_cfgs):
|
|
|
|
@ -304,9 +328,9 @@ def main():
|
|
|
|
|
start_batch_size = args.batch_size
|
|
|
|
|
for m, c in model_cfgs:
|
|
|
|
|
batch_size = start_batch_size
|
|
|
|
|
args.model = m
|
|
|
|
|
model_name = m
|
|
|
|
|
args.checkpoint = c
|
|
|
|
|
result = OrderedDict(model=args.model)
|
|
|
|
|
result = OrderedDict(model=model_name)
|
|
|
|
|
r = {}
|
|
|
|
|
while not r and batch_size >= args.num_gpu:
|
|
|
|
|
torch.cuda.empty_cache()
|
|
|
|
|