Get Torchvision and PyTorch Hub models with scripts

pull/429/head
Csaba Kertesz 5 years ago
parent 1ad1645a50
commit 333af604f8
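In short, each of the three scripts touched by this commit (they appear to be timm's inference.py, train.py and validate.py) gains the flags --torchvision-model, --hub-model and --hub-model-github-or-dir and picks its model source accordingly. Below is a minimal standalone sketch of that selection logic, assuming an already-parsed args namespace; the build_model helper itself is illustrative and not part of the patch:

import torch
import torchvision.models as models
from timm.models import create_model, load_checkpoint

def build_model(args):
    # Condensed version of the branch this commit adds to each script.
    if args.torchvision_model:
        # Torchvision constructor looked up by name, e.g. 'resnet50'.
        model_name = args.torchvision_model
        model = models.__dict__[model_name](
            pretrained=args.pretrained, num_classes=args.num_classes)
    elif args.hub_model and args.hub_model_github_or_dir:
        # PyTorch Hub entrypoint from a GitHub spec or a local hubconf.py directory.
        model_name = args.hub_model
        model = torch.hub.load(
            args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
    else:
        # Fall back to timm's own model registry.
        model_name = args.model
        model = create_model(
            args.model, pretrained=args.pretrained, num_classes=args.num_classes)
    if getattr(args, 'checkpoint', ''):
        # The scripts optionally load a timm-style checkpoint on top.
        load_checkpoint(model, args.checkpoint)
    return model_name, model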

@@ -12,7 +12,9 @@ import logging
 import numpy as np
 import torch
-from timm.models import create_model, apply_test_time_pool
+import torchvision.models as models
+from timm.models import create_model, apply_test_time_pool, load_checkpoint
 from timm.data import ImageDataset, create_loader, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging
@@ -27,6 +29,12 @@ parser.add_argument('--output_dir', metavar='DIR', default='./',
                     help='path to output files')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str,
+                    help='Specify local directory or Github repository to load model by PyTorch Hub')
 parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -43,7 +51,7 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
-parser.add_argument('--log-freq', default=10, type=int,
+parser.add_argument('--log-interval', default=10, type=int,
                     metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -64,6 +72,18 @@ def main():
     args.pretrained = args.pretrained or not args.checkpoint
     # create model
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+        if args.checkpoint:
+            load_checkpoint(model, args.checkpoint)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+        if args.checkpoint:
+            load_checkpoint(model, args.checkpoint)
+    else:
         model = create_model(
             args.model,
             num_classes=args.num_classes,
@@ -72,7 +92,7 @@ def main():
             checkpoint_path=args.checkpoint)
     _logger.info('Model %s created, param count: %d' %
-                 (args.model, sum([m.numel() for m in model.parameters()])))
+                 (model_name, sum([m.numel() for m in model.parameters()])))
     config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
@@ -110,7 +130,7 @@ def main():
             batch_time.update(time.time() - end)
             end = time.time()
-            if batch_idx % args.log_freq == 0:
+            if batch_idx % args.log_interval == 0:
                 _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                     batch_idx, len(loader), batch_time=batch_time))
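For the --hub-model path above, torch.hub.load resolves its first argument either as a GitHub spec ('owner/repo', optionally with ':tag_or_branch') or as a local checkout containing a hubconf.py; the patch passes --hub-model-github-or-dir straight through, so exact behaviour follows the installed PyTorch version. The repository and entrypoint below are only an example, not something this commit pins down:

import torch

# GitHub form: repository spec plus the name of a hubconf.py entrypoint.
model = torch.hub.load('pytorch/vision', 'resnet50', pretrained=True)

# Local form (supported via source='local' since PyTorch 1.7): a checkout providing hubconf.py.
# model = torch.hub.load('/path/to/vision', 'resnet50', source='local', pretrained=True)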

@@ -25,6 +25,7 @@ from datetime import datetime
 import torch
 import torch.nn as nn
+import torchvision.models as models
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
@@ -74,7 +75,13 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
-                    help='Name of model to train (default: "countception"')
+                    help='Name of model to train (default: "resnet101"')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str,
+                    help='Specify local directory or Github repository to load model by PyTorch Hub')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
@@ -327,6 +334,18 @@ def main():
     torch.manual_seed(args.seed + args.rank)
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+        if args.initial_checkpoint:
+            load_checkpoint(model, args.initial_checkpoint)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+        if args.initial_checkpoint:
+            load_checkpoint(model, args.initial_checkpoint)
+    else:
         model = create_model(
             args.model,
             pretrained=args.pretrained,
@@ -341,6 +360,7 @@ def main():
             bn_eps=args.bn_eps,
             scriptable=args.torchscript,
             checkpoint_path=args.initial_checkpoint)
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
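For the --torchvision-model path, the constructor is simply looked up in torchvision.models.__dict__ and called with pretrained and num_classes, roughly as below (the model name is just an example). One thing to keep in mind: the stock ImageNet weights assume a 1000-class head, so pretrained=True combined with a different num_classes typically raises a size-mismatch error in torchvision:

import torchvision.models as models

# Randomly initialised model with a custom head: fine.
model = models.__dict__['resnet50'](pretrained=False, num_classes=10)

# Pretrained ImageNet weights expect the default 1000-class head.
model = models.__dict__['resnet50'](pretrained=True, num_classes=1000)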

@@ -19,6 +19,8 @@ import torch.nn.parallel
 from collections import OrderedDict
 from contextlib import suppress
+import torchvision.models as models
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
@@ -50,6 +52,12 @@ parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str,
+                    help='Specify local directory or Github repository to load model by PyTorch Hub')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -72,7 +80,7 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                     help='path to class to idx mapping file (default: "")')
 parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
-parser.add_argument('--log-freq', default=10, type=int,
+parser.add_argument('--log-interval', default=10, type=int,
                     metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -135,6 +143,14 @@ def validate(args):
         set_jit_legacy()
     # create model
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+    else:
         model = create_model(
             args.model,
             pretrained=args.pretrained,
@@ -142,6 +158,7 @@ def validate(args):
             in_chans=3,
             global_pool=args.gp,
             scriptable=args.torchscript)
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes
@@ -150,7 +167,7 @@ def validate(args):
         load_checkpoint(model, args.checkpoint, args.use_ema)
     param_count = sum([m.numel() for m in model.parameters()])
-    _logger.info('Model %s created, param count: %d' % (args.model, param_count))
+    _logger.info('Model %s created, param count: %d' % (model_name, param_count))
     data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
     test_time_pool = False
@@ -244,7 +261,7 @@ def validate(args):
             batch_time.update(time.time() - end)
             end = time.time()
-            if batch_idx % args.log_freq == 0:
+            if batch_idx % args.log_interval == 0:
                 _logger.info(
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
@@ -277,23 +294,30 @@ def validate(args):
 def main():
     setup_default_logging()
     args = parser.parse_args()
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
     model_cfgs = []
     model_names = []
     if os.path.isdir(args.checkpoint):
         # validate all checkpoints in a path with same model
         checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
         checkpoints += glob.glob(args.checkpoint + '/*.pth')
-        model_names = list_models(args.model)
-        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
+        model_names = list_models(model_name)
+        model_cfgs = [(model_name, c) for c in sorted(checkpoints, key=natural_key)]
     else:
-        if args.model == 'all':
+        if model_name == 'all':
             # validate all models in a list of names with pretrained checkpoints
             args.pretrained = True
             model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
             model_cfgs = [(n, '') for n in model_names]
-        elif not is_model(args.model):
+        elif not is_model(model_name):
             # model name doesn't exist, try as wildcard filter
-            model_names = list_models(args.model)
+            model_names = list_models(model_name)
             model_cfgs = [(n, '') for n in model_names]
     if len(model_cfgs):
@@ -304,9 +328,9 @@ def main():
             start_batch_size = args.batch_size
             for m, c in model_cfgs:
                 batch_size = start_batch_size
-                args.model = m
+                model_name = m
                 args.checkpoint = c
-                result = OrderedDict(model=args.model)
+                result = OrderedDict(model=model_name)
                 r = {}
                 while not r and batch_size >= args.num_gpu:
                     torch.cuda.empty_cache()
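The main() changes above keep timm's bulk-validation behaviour but key it on model_name instead of args.model: a name that is not registered in timm is treated as a wildcard filter for list_models. A quick illustration of that API; the pattern and the returned names are examples and depend on the installed timm version:

from timm.models import is_model, list_models

is_model('resnest26d')   # True when the name is registered in timm
list_models('resnest*')  # e.g. ['resnest14d', 'resnest26d', 'resnest50d', ...]
list_models(pretrained=True, exclude_filters=['*in21k'])  # names with pretrained weights, minus the in21k variants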
