Get Torchvision and Pytorch Hub models with scripts

pull/429/head
Csaba Kertesz 5 years ago
parent 1ad1645a50
commit 333af604f8

@ -12,7 +12,9 @@ import logging
import numpy as np
import torch
from timm.models import create_model, apply_test_time_pool
import torchvision.models as models
from timm.models import create_model, apply_test_time_pool, load_checkpoint
from timm.data import ImageDataset, create_loader, resolve_data_config
from timm.utils import AverageMeter, setup_default_logging
@ -27,6 +29,12 @@ parser.add_argument('--output_dir', metavar='DIR', default='./',
help='path to output files')
parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
help='Get a Torchvision model by name')
parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
help='Get a model from PyTorch Hub by name')
parser.add_argument('--hub-model-github-or-dir', type=str,
help='Specify local directory or Github repository to load model by PyTorch Hub')
parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -43,7 +51,7 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
parser.add_argument('--num-classes', type=int, default=1000,
help='Number classes in dataset')
parser.add_argument('--log-freq', default=10, type=int,
parser.add_argument('--log-interval', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
@ -64,15 +72,27 @@ def main():
args.pretrained = args.pretrained or not args.checkpoint
# create model
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
model_name = args.model
if args.torchvision_model:
model_name = args.torchvision_model
model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
if args.checkpoint:
load_checkpoint(model, args.checkpoint)
elif args.hub_model and args.hub_model_github_or_dir:
model_name = args.hub_model
model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
if args.checkpoint:
load_checkpoint(model, args.checkpoint)
else:
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint)
_logger.info('Model %s created, param count: %d' %
(args.model, sum([m.numel() for m in model.parameters()])))
(model_name, sum([m.numel() for m in model.parameters()])))
config = resolve_data_config(vars(args), model=model)
model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
@ -110,7 +130,7 @@ def main():
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
if batch_idx % args.log_interval == 0:
_logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
batch_idx, len(loader), batch_time=batch_time))

@ -25,6 +25,7 @@ from datetime import datetime
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.utils
from torch.nn.parallel import DistributedDataParallel as NativeDDP
@ -74,7 +75,13 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
parser.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
help='Name of model to train (default: "countception"')
help='Name of model to train (default: "resnet101"')
parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
help='Get a Torchvision model by name')
parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
help='Get a model from PyTorch Hub by name')
parser.add_argument('--hub-model-github-or-dir', type=str,
help='Specify local directory or Github repository to load model by PyTorch Hub')
parser.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
@ -327,20 +334,33 @@ def main():
torch.manual_seed(args.seed + args.rank)
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
model_name = args.model
if args.torchvision_model:
model_name = args.torchvision_model
model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
if args.initial_checkpoint:
load_checkpoint(model, args.initial_checkpoint)
elif args.hub_model and args.hub_model_github_or_dir:
model_name = args.hub_model
model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
if args.initial_checkpoint:
load_checkpoint(model, args.initial_checkpoint)
else:
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
drop_rate=args.drop,
drop_connect_rate=args.drop_connect, # DEPRECATED, use drop_path
drop_path_rate=args.drop_path,
drop_block_rate=args.drop_block,
global_pool=args.gp,
bn_tf=args.bn_tf,
bn_momentum=args.bn_momentum,
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly

@ -19,6 +19,8 @@ import torch.nn.parallel
from collections import OrderedDict
from contextlib import suppress
import torchvision.models as models
from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
@ -50,6 +52,12 @@ parser.add_argument('--split', metavar='NAME', default='validation',
help='dataset split (default: validation)')
parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
help='Get a Torchvision model by name')
parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
help='Get a model from PyTorch Hub by name')
parser.add_argument('--hub-model-github-or-dir', type=str,
help='Specify local directory or Github repository to load model by PyTorch Hub')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 2)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
@ -72,7 +80,7 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
parser.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
parser.add_argument('--log-freq', default=10, type=int,
parser.add_argument('--log-interval', default=10, type=int,
metavar='N', help='batch logging frequency (default: 10)')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
help='path to latest checkpoint (default: none)')
@ -135,13 +143,22 @@ def validate(args):
set_jit_legacy()
# create model
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
model_name = args.model
if args.torchvision_model:
model_name = args.torchvision_model
model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
elif args.hub_model and args.hub_model_github_or_dir:
model_name = args.hub_model
model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
else:
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
global_pool=args.gp,
scriptable=args.torchscript)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
args.num_classes = model.num_classes
@ -150,7 +167,7 @@ def validate(args):
load_checkpoint(model, args.checkpoint, args.use_ema)
param_count = sum([m.numel() for m in model.parameters()])
_logger.info('Model %s created, param count: %d' % (args.model, param_count))
_logger.info('Model %s created, param count: %d' % (model_name, param_count))
data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
test_time_pool = False
@ -244,7 +261,7 @@ def validate(args):
batch_time.update(time.time() - end)
end = time.time()
if batch_idx % args.log_freq == 0:
if batch_idx % args.log_interval == 0:
_logger.info(
'Test: [{0:>4d}/{1}] '
'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
@ -277,23 +294,30 @@ def validate(args):
def main():
setup_default_logging()
args = parser.parse_args()
model_name = args.model
if args.torchvision_model:
model_name = args.torchvision_model
elif args.hub_model and args.hub_model_github_or_dir:
model_name = args.hub_model
model_cfgs = []
model_names = []
if os.path.isdir(args.checkpoint):
# validate all checkpoints in a path with same model
checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
checkpoints += glob.glob(args.checkpoint + '/*.pth')
model_names = list_models(args.model)
model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
model_names = list_models(model_name)
model_cfgs = [(model_name, c) for c in sorted(checkpoints, key=natural_key)]
else:
if args.model == 'all':
if model_name == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
elif not is_model(model_name):
# model name doesn't exist, try as wildcard filter
model_names = list_models(args.model)
model_names = list_models(model_name)
model_cfgs = [(n, '') for n in model_names]
if len(model_cfgs):
@ -304,9 +328,9 @@ def main():
start_batch_size = args.batch_size
for m, c in model_cfgs:
batch_size = start_batch_size
args.model = m
model_name = m
args.checkpoint = c
result = OrderedDict(model=args.model)
result = OrderedDict(model=model_name)
r = {}
while not r and batch_size >= args.num_gpu:
torch.cuda.empty_cache()

Loading…
Cancel
Save