diff --git a/inference.py b/inference.py
index 89efb1fb..e9fd89ef 100755
--- a/inference.py
+++ b/inference.py
@@ -12,7 +12,9 @@ import logging
 import numpy as np
 import torch
 
-from timm.models import create_model, apply_test_time_pool
+import torchvision.models as models
+
+from timm.models import create_model, apply_test_time_pool, load_checkpoint
 from timm.data import ImageDataset, create_loader, resolve_data_config
 from timm.utils import AverageMeter, setup_default_logging
 
@@ -27,6 +29,12 @@ parser.add_argument('--output_dir', metavar='DIR', default='./',
                     help='path to output files')
 parser.add_argument('--model', '-m', metavar='MODEL', default='dpn92',
                     help='model architecture (default: dpn92)')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str, metavar='REPO_OR_DIR',
+                    help='Local directory or GitHub repository to load the model from via PyTorch Hub')
 parser.add_argument('-j', '--workers', default=2, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -43,7 +51,7 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('--num-classes', type=int, default=1000,
                     help='Number classes in dataset')
-parser.add_argument('--log-freq', default=10, type=int,
+parser.add_argument('--log-interval', default=10, type=int,
                     metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -64,15 +72,27 @@ def main():
     args.pretrained = args.pretrained or not args.checkpoint
 
     # create model
-    model = create_model(
-        args.model,
-        num_classes=args.num_classes,
-        in_chans=3,
-        pretrained=args.pretrained,
-        checkpoint_path=args.checkpoint)
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+        if args.checkpoint:
+            load_checkpoint(model, args.checkpoint)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+        if args.checkpoint:
+            load_checkpoint(model, args.checkpoint)
+    else:
+        model = create_model(
+            args.model,
+            num_classes=args.num_classes,
+            in_chans=3,
+            pretrained=args.pretrained,
+            checkpoint_path=args.checkpoint)
 
     _logger.info('Model %s created, param count: %d' %
-                 (args.model, sum([m.numel() for m in model.parameters()])))
+                 (model_name, sum([m.numel() for m in model.parameters()])))
 
     config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = (model, False) if args.no_test_pool else apply_test_time_pool(model, config)
@@ -110,7 +130,7 @@ def main():
             batch_time.update(time.time() - end)
             end = time.time()
 
-            if batch_idx % args.log_freq == 0:
+            if batch_idx % args.log_interval == 0:
                 _logger.info('Predict: [{0}/{1}] Time {batch_time.val:.3f} ({batch_time.avg:.3f})'.format(
                     batch_idx, len(loader), batch_time=batch_time))
 
diff --git a/train.py b/train.py
index f3da4a36..5fa0be4e 100755
--- a/train.py
+++ b/train.py
@@ -25,6 +25,7 @@ from datetime import datetime
 
 import torch
 import torch.nn as nn
+import torchvision.models as models
 import torchvision.utils
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
@@ -74,7 +75,13 @@ parser.add_argument('--train-split', metavar='NAME', default='train',
 parser.add_argument('--val-split', metavar='NAME', default='validation',
                     help='dataset validation split (default: validation)')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
-                    help='Name of model to train (default: "countception"')
+                    help='Name of model to train (default: "resnet101")')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str, metavar='REPO_OR_DIR',
+                    help='Local directory or GitHub repository to load the model from via PyTorch Hub')
 parser.add_argument('--pretrained', action='store_true', default=False,
                     help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
@@ -327,20 +334,33 @@ def main():
 
     torch.manual_seed(args.seed + args.rank)
 
-    model = create_model(
-        args.model,
-        pretrained=args.pretrained,
-        num_classes=args.num_classes,
-        drop_rate=args.drop,
-        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
-        drop_path_rate=args.drop_path,
-        drop_block_rate=args.drop_block,
-        global_pool=args.gp,
-        bn_tf=args.bn_tf,
-        bn_momentum=args.bn_momentum,
-        bn_eps=args.bn_eps,
-        scriptable=args.torchscript,
-        checkpoint_path=args.initial_checkpoint)
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+        if args.initial_checkpoint:
+            load_checkpoint(model, args.initial_checkpoint)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+        if args.initial_checkpoint:
+            load_checkpoint(model, args.initial_checkpoint)
+    else:
+        model = create_model(
+            args.model,
+            pretrained=args.pretrained,
+            num_classes=args.num_classes,
+            drop_rate=args.drop,
+            drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
+            drop_path_rate=args.drop_path,
+            drop_block_rate=args.drop_block,
+            global_pool=args.gp,
+            bn_tf=args.bn_tf,
+            bn_momentum=args.bn_momentum,
+            bn_eps=args.bn_eps,
+            scriptable=args.torchscript,
+            checkpoint_path=args.initial_checkpoint)
+
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes  # FIXME handle model default vs config num_classes more elegantly
diff --git a/validate.py b/validate.py
index 3f201314..5d77fcba 100755
--- a/validate.py
+++ b/validate.py
@@ -19,6 +19,8 @@ import torch.nn.parallel
 from collections import OrderedDict
 from contextlib import suppress
 
+import torchvision.models as models
+
 from timm.models import create_model, apply_test_time_pool, load_checkpoint, is_model, list_models
 from timm.data import create_dataset, create_loader, resolve_data_config, RealLabelsImagenet
 from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_legacy
@@ -50,6 +52,12 @@ parser.add_argument('--split', metavar='NAME', default='validation',
                     help='dataset split (default: validation)')
 parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
                     help='model architecture (default: dpn92)')
+parser.add_argument('--torchvision-model', default='', type=str, metavar='MODEL',
+                    help='Get a Torchvision model by name')
+parser.add_argument('--hub-model', default='', type=str, metavar='MODEL',
+                    help='Get a model from PyTorch Hub by name')
+parser.add_argument('--hub-model-github-or-dir', type=str, metavar='REPO_OR_DIR',
+                    help='Local directory or GitHub repository to load the model from via PyTorch Hub')
 parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                     help='number of data loading workers (default: 2)')
 parser.add_argument('-b', '--batch-size', default=256, type=int,
@@ -72,7 +80,7 @@ parser.add_argument('--class-map', default='', type=str, metavar='FILENAME',
                     help='path to class to idx mapping file (default: "")')
 parser.add_argument('--gp', default=None, type=str, metavar='POOL',
                     help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
-parser.add_argument('--log-freq', default=10, type=int,
+parser.add_argument('--log-interval', default=10, type=int,
                     metavar='N', help='batch logging frequency (default: 10)')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
@@ -135,13 +143,22 @@ def validate(args):
         set_jit_legacy()
 
     # create model
-    model = create_model(
-        args.model,
-        pretrained=args.pretrained,
-        num_classes=args.num_classes,
-        in_chans=3,
-        global_pool=args.gp,
-        scriptable=args.torchscript)
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+        model = models.__dict__[args.torchvision_model](pretrained=args.pretrained, num_classes=args.num_classes)
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+        model = torch.hub.load(args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
+    else:
+        model = create_model(
+            args.model,
+            pretrained=args.pretrained,
+            num_classes=args.num_classes,
+            in_chans=3,
+            global_pool=args.gp,
+            scriptable=args.torchscript)
+
     if args.num_classes is None:
         assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
         args.num_classes = model.num_classes
@@ -150,7 +167,7 @@ def validate(args):
         load_checkpoint(model, args.checkpoint, args.use_ema)
 
     param_count = sum([m.numel() for m in model.parameters()])
-    _logger.info('Model %s created, param count: %d' % (args.model, param_count))
+    _logger.info('Model %s created, param count: %d' % (model_name, param_count))
 
     data_config = resolve_data_config(vars(args), model=model, use_test_size=True)
     test_time_pool = False
@@ -244,7 +261,7 @@ def validate(args):
             batch_time.update(time.time() - end)
             end = time.time()
 
-            if batch_idx % args.log_freq == 0:
+            if batch_idx % args.log_interval == 0:
                 _logger.info(
                     'Test: [{0:>4d}/{1}] '
                     'Time: {batch_time.val:.3f}s ({batch_time.avg:.3f}s, {rate_avg:>7.2f}/s) '
@@ -277,23 +294,30 @@
 def main():
     setup_default_logging()
     args = parser.parse_args()
+
+    model_name = args.model
+    if args.torchvision_model:
+        model_name = args.torchvision_model
+    elif args.hub_model and args.hub_model_github_or_dir:
+        model_name = args.hub_model
+
     model_cfgs = []
     model_names = []
     if os.path.isdir(args.checkpoint):
         # validate all checkpoints in a path with same model
         checkpoints = glob.glob(args.checkpoint + '/*.pth.tar')
         checkpoints += glob.glob(args.checkpoint + '/*.pth')
-        model_names = list_models(args.model)
-        model_cfgs = [(args.model, c) for c in sorted(checkpoints, key=natural_key)]
+        model_names = list_models(model_name)
+        model_cfgs = [(model_name, c) for c in sorted(checkpoints, key=natural_key)]
     else:
-        if args.model == 'all':
+        if model_name == 'all':
             # validate all models in a list of names with pretrained checkpoints
             args.pretrained = True
             model_names = list_models(pretrained=True, exclude_filters=['*in21k'])
             model_cfgs = [(n, '') for n in model_names]
-        elif not is_model(args.model):
+        elif not is_model(model_name):
             # model name doesn't exist, try as wildcard filter
-            model_names = list_models(args.model)
+            model_names = list_models(model_name)
             model_cfgs = [(n, '') for n in model_names]
 
     if len(model_cfgs):
@@ -304,9 +328,10 @@ def main():
             start_batch_size = args.batch_size
             for m, c in model_cfgs:
                 batch_size = start_batch_size
                 args.model = m
+                model_name = m
                 args.checkpoint = c
-                result = OrderedDict(model=args.model)
+                result = OrderedDict(model=model_name)
                 r = {}
                 while not r and batch_size >= args.num_gpu:
                     torch.cuda.empty_cache()
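
Note: the torchvision / torch.hub / timm selection branch added above is repeated, with small variations, in inference.py, train.py and validate.py. Below is a minimal sketch of how the same logic could be factored into one shared helper; the helper name build_model() and its keyword arguments are illustrative only and are not part of this diff. It relies solely on calls the diff itself uses (torchvision.models, torch.hub.load, and timm's create_model / load_checkpoint).

    # Sketch only -- not part of the diff. `build_model` is a hypothetical helper
    # mirroring the branching added to inference.py / train.py / validate.py above.
    import torch
    import torchvision.models as models

    from timm.models import create_model, load_checkpoint


    def build_model(args, checkpoint='', **timm_kwargs):
        """Return (model_name, model) from torchvision, torch.hub, or timm."""
        if args.torchvision_model:
            # torchvision: look the constructor up by name, as the diff does;
            # torchvision constructors expect an integer num_classes (e.g. 1000)
            model_name = args.torchvision_model
            model = models.__dict__[args.torchvision_model](
                pretrained=args.pretrained, num_classes=args.num_classes)
            if checkpoint:
                load_checkpoint(model, checkpoint)
        elif args.hub_model and args.hub_model_github_or_dir:
            # torch.hub: the repo can be a 'user/repo' GitHub reference or a local dir
            model_name = args.hub_model
            model = torch.hub.load(
                args.hub_model_github_or_dir, args.hub_model, pretrained=args.pretrained)
            if checkpoint:
                load_checkpoint(model, checkpoint)
        else:
            # default timm path; extra kwargs (drop rates, bn overrides, ...) pass through
            model_name = args.model
            model = create_model(
                args.model,
                pretrained=args.pretrained,
                num_classes=args.num_classes,
                checkpoint_path=checkpoint,
                **timm_kwargs)
        return model_name, model

Each script would then reduce to a call like model_name, model = build_model(args, checkpoint=args.checkpoint) (args.initial_checkpoint in train.py). Example invocations of the new flags (the dataset path and model names below are placeholders):

    python validate.py /path/to/imagenet --torchvision-model resnet50 --pretrained
    python validate.py /path/to/imagenet --hub-model resnext50_32x4d --hub-model-github-or-dir pytorch/vision --pretrained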