From c6b32cbe734967f42aaed21d9eb0f96cebe9e991 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 28 Jun 2019 13:49:20 -0700
Subject: [PATCH 1/6] A number of tweaks to arguments, epoch handling, config

* reorganize train args
* allow resolve_data_config to be used with dict args, not just argparse
* stop incrementing epoch before save, more consistent naming vs csv, etc
* update resume and start epoch handling to match above
* stop auto-incrementing epoch in scheduler
---
 inference.py                    |  2 +-
 timm/data/config.py             | 68 ++++++++++++---------------
 timm/data/loader.py             |  2 +
 timm/data/transforms.py         | 10 +++-
 timm/models/gen_efficientnet.py |  4 +-
 timm/models/helpers.py          | 11 +++--
 timm/scheduler/scheduler.py     |  2 +-
 timm/utils.py                   |  3 +-
 train.py                        | 83 +++++++++++++++++++--------------
 validate.py                     |  2 +-
 10 files changed, 104 insertions(+), 83 deletions(-)

diff --git a/inference.py b/inference.py
index 9077cc07..3255a8d9 100644
--- a/inference.py
+++ b/inference.py
@@ -70,7 +70,7 @@ def main():
     logging.info('Model %s created, param count: %d' %
                  (args.model, sum([m.numel() for m in model.parameters()])))
-    config = resolve_data_config(model, args)
+    config = resolve_data_config(vars(args), model=model)
     model, test_time_pool = apply_test_time_pool(model, config, args)
     if args.num_gpu > 1:
diff --git a/timm/data/config.py b/timm/data/config.py
index 1675d2a9..8a83d19f 100644
--- a/timm/data/config.py
+++ b/timm/data/config.py
@@ -2,35 +2,43 @@ import logging
 from .constants import *
-def resolve_data_config(model, args, default_cfg={}, verbose=True):
+def resolve_data_config(args, default_cfg={}, model=None, verbose=True):
     new_config = {}
     default_cfg = default_cfg
-    if not default_cfg and hasattr(model, 'default_cfg'):
+    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg
     # Resolve input/image size
-    # FIXME grayscale/chans arg to use different # channels?
in_chans = 3 + if 'chans' in args and args['chans'] is not None: + in_chans = args['chans'] + input_size = (in_chans, 224, 224) - if args.img_size is not None: - # FIXME support passing img_size as tuple, non-square - assert isinstance(args.img_size, int) - input_size = (in_chans, args.img_size, args.img_size) + if 'input_size' in args and args['input_size'] is not None: + assert isinstance(args['input_size'], (tuple, list)) + assert len(args['input_size']) == 3 + input_size = tuple(args['input_size']) + in_chans = input_size[0] # input_size overrides in_chans + elif 'img_size' in args and args['img_size'] is not None: + assert isinstance(args['img_size'], int) + input_size = (in_chans, args['img_size'], args['img_size']) elif 'input_size' in default_cfg: input_size = default_cfg['input_size'] new_config['input_size'] = input_size # resolve interpolation method - new_config['interpolation'] = 'bilinear' - if args.interpolation: - new_config['interpolation'] = args.interpolation + new_config['interpolation'] = 'bicubic' + if 'interpolation' in args and args['interpolation']: + new_config['interpolation'] = args['interpolation'] elif 'interpolation' in default_cfg: new_config['interpolation'] = default_cfg['interpolation'] # resolve dataset + model mean for normalization - new_config['mean'] = get_mean_by_model(args.model) - if args.mean is not None: - mean = tuple(args.mean) + new_config['mean'] = IMAGENET_DEFAULT_MEAN + if 'model' in args: + new_config['mean'] = get_mean_by_model(args['model']) + if 'mean' in args and args['mean'] is not None: + mean = tuple(args['mean']) if len(mean) == 1: mean = tuple(list(mean) * in_chans) else: @@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): new_config['mean'] = default_cfg['mean'] # resolve dataset + model std deviation for normalization - new_config['std'] = get_std_by_model(args.model) - if args.std is not None: - std = tuple(args.std) + new_config['std'] = IMAGENET_DEFAULT_STD + if 'model' in args: + new_config['std'] = get_std_by_model(args['model']) + if 'std' in args and args['std'] is not None: + std = tuple(args['std']) if len(std) == 1: std = tuple(list(std) * in_chans) else: @@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): # resolve default crop percentage new_config['crop_pct'] = DEFAULT_CROP_PCT - if 'crop_pct' in default_cfg: + if 'crop_pct' in args and args['crop_pct'] is not None: + new_config['crop_pct'] = args['crop_pct'] + elif 'crop_pct' in default_cfg: new_config['crop_pct'] = default_cfg['crop_pct'] if verbose: @@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True): return new_config -def get_mean_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_MEAN - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_MEAN - else: - return IMAGENET_DEFAULT_MEAN - - -def get_std_by_name(name): - if name == 'dpn': - return IMAGENET_DPN_STD - elif name == 'inception' or name == 'le': - return IMAGENET_INCEPTION_STD - else: - return IMAGENET_DEFAULT_STD - - def get_mean_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DPN_STD - elif 'ception' in model_name or 'nasnet' in model_name: + elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): return IMAGENET_INCEPTION_MEAN else: return IMAGENET_DEFAULT_MEAN @@ -96,7 +90,7 @@ def get_std_by_model(model_name): model_name = model_name.lower() if 'dpn' in model_name: return IMAGENET_DEFAULT_STD - 
elif 'ception' in model_name or 'nasnet' in model_name: + elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name): return IMAGENET_INCEPTION_STD else: return IMAGENET_DEFAULT_STD diff --git a/timm/data/loader.py b/timm/data/loader.py index 777eb878..6a19b805 100644 --- a/timm/data/loader.py +++ b/timm/data/loader.py @@ -86,6 +86,7 @@ def create_loader( use_prefetcher=True, rand_erase_prob=0., rand_erase_mode='const', + color_jitter=0.4, interpolation='bilinear', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, @@ -107,6 +108,7 @@ def create_loader( if is_training: transform = transforms_imagenet_train( img_size, + color_jitter=color_jitter, interpolation=interpolation, use_prefetcher=use_prefetcher, mean=mean, diff --git a/timm/data/transforms.py b/timm/data/transforms.py index bee505a2..1e1b054a 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object): def transforms_imagenet_train( img_size=224, scale=(0.08, 1.0), - color_jitter=(0.4, 0.4, 0.4), + color_jitter=0.4, interpolation='random', random_erasing=0.4, random_erasing_mode='const', @@ -164,6 +164,14 @@ def transforms_imagenet_train( mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD ): + if isinstance(color_jitter, (list, tuple)): + # color jitter should be a 3-tuple/list if spec brightness/contrast/saturation + # or 4 if also augmenting hue + assert len(color_jitter) in (3, 4) + else: + # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue + color_jitter = (float(color_jitter),) * 3 + print(*color_jitter) tfl = [ RandomResizedCropAndInterpolation( diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 1f5890bc..0642a1cb 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -1430,7 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B1 """ default_cfg = default_cfgs['efficientnet_b1'] # NOTE for train, drop_rate should be 0.2 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.0, depth_multiplier=1.1, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1445,7 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B2 """ default_cfg = default_cfgs['efficientnet_b2'] # NOTE for train, drop_rate should be 0.3 - #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.1, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b7de304a..1deff273 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False): raise FileNotFoundError() -def resume_checkpoint(model, checkpoint_path, start_epoch=None): +def resume_checkpoint(model, checkpoint_path): optimizer_state = None + resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: @@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None): model.load_state_dict(new_state_dict) if 'optimizer' in 
checkpoint: optimizer_state = checkpoint['optimizer'] - start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - start_epoch = 0 if start_epoch is None else start_epoch logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return optimizer_state, start_epoch + return optimizer_state, resume_epoch else: logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py index 59fcfc16..78e8460d 100644 --- a/timm/scheduler/scheduler.py +++ b/timm/scheduler/scheduler.py @@ -56,7 +56,7 @@ class Scheduler: def step(self, epoch: int, metric: float = None) -> None: self.metric = metric - values = self.get_epoch_values(epoch + 1) # +1 to calculate for next epoch + values = self.get_epoch_values(epoch) if values is not None: self.update_groups(values) diff --git a/timm/utils.py b/timm/utils.py index 8d4418a6..36355c2b 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -83,7 +83,8 @@ class CheckpointSaver: 'arch': args.model, 'state_dict': get_state_dict(model), 'optimizer': optimizer.state_dict(), - 'args': args + 'args': args, + 'version': 2, # version < 2 increments epoch before save } if model_ema is not None: save_state['state_dict_ema'] = get_state_dict(model_ema) diff --git a/train.py b/train.py index 7196305f..f7ecdd5d 100644 --- a/train.py +++ b/train.py @@ -27,22 +27,21 @@ import torchvision.utils torch.backends.cudnn.benchmark = True parser = argparse.ArgumentParser(description='Training') +# Dataset / Model parameters parser.add_argument('data', metavar='DIR', help='path to dataset') parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL', help='Name of model to train (default: "countception"') +parser.add_argument('--pretrained', action='store_true', default=False, + help='Start with pretrained version of specified network (if avail)') +parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', + help='Initialize model from this checkpoint (default: none)') +parser.add_argument('--resume', default='', type=str, metavar='PATH', + help='Resume full model and optimizer state from checkpoint (default: none)') parser.add_argument('--num-classes', type=int, default=1000, metavar='N', help='number of label classes (default: 1000)') -parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', - help='Optimizer (default: "sgd"') -parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', - help='Optimizer Epsilon (default: 1e-8)') parser.add_argument('--gp', default='avg', type=str, metavar='POOL', help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') -parser.add_argument('--tta', type=int, default=0, metavar='N', - help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') -parser.add_argument('--pretrained', action='store_true', default=False, - help='Start with pretrained version of specified network (if avail)') parser.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', @@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N', help='input batch size for training (default: 32)') -parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N', - help='initial input batch size for training (default: 0)') +parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', + help='Dropout rate (default: 0.)') +# Optimizer parameters +parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', + help='Optimizer (default: "sgd"') +parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON', + help='Optimizer Epsilon (default: 1e-8)') +parser.add_argument('--momentum', type=float, default=0.9, metavar='M', + help='SGD momentum (default: 0.9)') +parser.add_argument('--weight-decay', type=float, default=0.0001, + help='weight decay (default: 0.0001)') +# Learning rate schedule parameters +parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', + help='LR scheduler (default: "step"') +parser.add_argument('--lr', type=float, default=0.01, metavar='LR', + help='learning rate (default: 0.01)') +parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', + help='warmup learning rate (default: 0.0001)') parser.add_argument('--epochs', type=int, default=200, metavar='N', help='number of epochs to train (default: 2)') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', @@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') -parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER', - help='LR scheduler (default: "step"') -parser.add_argument('--drop', type=float, default=0.0, metavar='DROP', - help='Dropout rate (default: 0.)') +# Augmentation parameters +parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT', + help='Color jitter factor (default: 0.4)') parser.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') parser.add_argument('--remode', type=str, default='const', help='Random erase mode (default: "const")') -parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') -parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', - help='warmup learning rate (default: 0.0001)') -parser.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='SGD momentum (default: 0.9)') -parser.add_argument('--weight-decay', type=float, default=0.0001, - help='weight decay (default: 0.0001)') parser.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. 
(default: 0.)') parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='turn off mixup after this epoch, disabled if 0 (default: 0)') parser.add_argument('--smoothing', type=float, default=0.1, help='label smoothing (default: 0.1)') +# Batch norm parameters (only works with gen_efficientnet based models currently) parser.add_argument('--bn-tf', action='store_true', default=False, help='Use Tensorflow BatchNorm defaults for models that support it (default: False)') parser.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') parser.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') +# Model Exponential Moving Average parser.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') parser.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') +# Misc parser.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') parser.add_argument('--log-interval', type=int, default=50, metavar='N', @@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 1)') parser.add_argument('--num-gpu', type=int, default=1, help='Number of GPUS to use') -parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', - help='path to init checkpoint (default: none)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', - help='path to latest checkpoint (default: none)') parser.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') parser.add_argument('--amp', action='store_true', default=False, @@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "prec1"') +parser.add_argument('--tta', type=int, default=0, metavar='N', + help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) @@ -174,13 +181,13 @@ def main(): logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # optionally resume from a checkpoint - start_epoch = 0 optimizer_state = None + resume_epoch = None if args.resume: - optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch) + optimizer_state, resume_epoch = resume_checkpoint(model, args.resume) if args.num_gpu > 1: if args.amp: @@ -232,8 +239,15 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch if start_epoch > 0: lr_scheduler.step(start_epoch) + if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) @@ -255,6 +269,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], @@ -327,7 +342,8 @@ def main(): eval_metrics = ema_eval_metrics if lr_scheduler is not None: - lr_scheduler.step(epoch, eval_metrics[eval_metric]) + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), @@ -338,9 +354,7 @@ def main(): save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, - epoch=epoch + 1, - model_ema=model_ema, - metric=save_metric) + epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass @@ -433,9 +447,8 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - save_epoch = epoch + 1 if last_batch else epoch saver.save_recovery( - model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx) + model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) diff --git a/validate.py b/validate.py index 199888ad..280bc260 100644 --- a/validate.py +++ b/validate.py @@ -71,7 +71,7 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) logging.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(model, args) + data_config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, data_config, args) if args.num_gpu > 1: From 65a634626fde0d38e25e8f801bd4926aebbbfb06 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 2019 10:03:13 -0700 Subject: [PATCH 2/6] Switch random erasing to doing normal_() on CPU to avoid instability, remove a debug print --- timm/data/random_erasing.py | 7 ++++--- timm/data/transforms.py | 1 - 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py index c16725ae..e66f7b95 100644 --- 
a/timm/data/random_erasing.py +++ b/timm/data/random_erasing.py @@ -6,12 +6,13 @@ import torch def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'): # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() # paths, flip the order so normal is run on CPU if this becomes a problem - # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device) + # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 + # will revert back to doing normal_() on GPU when it's in next release if per_pixel: return torch.empty( - patch_size, dtype=dtype, device=device).normal_() + patch_size, dtype=dtype).normal_().to(device=device) elif rand_color: - return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_() + return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device) else: return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) diff --git a/timm/data/transforms.py b/timm/data/transforms.py index 1e1b054a..13a6ff01 100644 --- a/timm/data/transforms.py +++ b/timm/data/transforms.py @@ -171,7 +171,6 @@ def transforms_imagenet_train( else: # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue color_jitter = (float(color_jitter),) * 3 - print(*color_jitter) tfl = [ RandomResizedCropAndInterpolation( From b8762cc67d79464ae3105fecf6e6f35ebe8ae230 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 2019 15:37:42 -0700 Subject: [PATCH 3/6] Model updates. Add my best ResNet50 weights top-1=78.47. Add some other torchvision weights. * Remove some models that don't exist as pretrained an likely never will (se)resnext152 * Add some torchvision weights as tv_ for models that I have added better weights for * Add wide resnet recently added to torchvision along with resnext101-32x8d * Add functionality to model registry to allow filtering on pretrained weight presence --- README.md | 1 + timm/models/dpn.py | 4 +- timm/models/gluon_resnet.py | 30 ------------ timm/models/registry.py | 34 +++++++++---- timm/models/resnet.py | 95 +++++++++++++++++++++++++++++++++---- 5 files changed, 114 insertions(+), 50 deletions(-) diff --git a/README.md b/README.md index 64d8a6f1..ac2fbdf6 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic | +| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic | | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic | | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 76b59ca2..92bc7855 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -35,9 +35,9 @@ def _cfg(url=''): default_cfgs = { 'dpn68': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), - 'dpn68b_extra': _cfg( + 'dpn68b': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'), - 'dpn92_extra': _cfg( + 'dpn92': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'), 'dpn98': _cfg( 
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'), diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index c5d0634f..715e0950 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -50,11 +50,9 @@ default_cfgs = { 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), - 'gluon_resnext152_32x4d': _cfg(url=''), 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), - 'gluon_seresnext152_32x4d': _cfg(url=''), 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'), } @@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model -@register_model -def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. - """ - default_cfg = default_cfgs['gluon_resnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt50-32x4d model. @@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model -@register_model -def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a SEResNeXt152-32x4d model. - """ - default_cfg = default_cfgs['gluon_seresnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - #if pretrained: - # load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. 
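Note on the timm/models/registry.py diff that follows: list_models() gains a pretrained flag so listings can be restricted to models that have a pretrained weight URL registered in their default_cfg. A minimal usage sketch against the new signature list_models(filter='', module='', pretrained=False); the 'resnext*' filter pattern and variable names are only illustrations:

```python
from timm.models.registry import list_models

# every ResNeXt variant known to the registry
all_resnext = list_models('resnext*')

# only the variants whose default_cfg carries a downloadable pretrained weight URL
pretrained_resnext = list_models('resnext*', pretrained=True)

print(all_resnext)
print(pretrained_resnext)
```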
diff --git a/timm/models/registry.py b/timm/models/registry.py index 45bc1809..c15f5414 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -5,22 +5,36 @@ from collections import defaultdict __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] -_module_to_models = defaultdict(set) -_model_to_module = {} -_model_entrypoints = {} +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present def register_model(fn): + # lookup containing module mod = sys.modules[fn.__module__] module_name_split = fn.__module__.split('.') module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ if hasattr(mod, '__all__'): - mod.__all__.append(fn.__name__) + mod.__all__.append(model_name) else: - mod.__all__ = [fn.__name__] - _model_entrypoints[fn.__name__] = fn - _model_to_module[fn.__name__] = module_name - _module_to_models[module_name].add(fn.__name__) + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + if has_pretrained: + _model_has_pretrained.add(model_name) return fn @@ -28,7 +42,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module=''): +def list_models(filter='', module='', pretrained=False): """ Return list of available model names, sorted alphabetically Args: @@ -45,6 +59,8 @@ def list_models(filter='', module=''): models = _model_entrypoints.keys() if filter: models = fnmatch.filter(models, filter) + if pretrained: + models = _model_has_pretrained.intersection(models) return list(sorted(models, key=_natural_key)) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7ed3b2e1..9a4b22cd 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -33,14 +33,22 @@ default_cfgs = { 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), - 'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'resnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', + interpolation='bicubic'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), - 'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1', - interpolation='bicubic'), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': 
_cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + 'resnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth', + interpolation='bicubic'), 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), - 'resnext152_32x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), @@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet34'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet50'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model = ResNet( + Bottleneck, [3, 4, 6, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet50_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-100-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. + """ + model = ResNet( + Bottleneck, [3, 4, 23, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet101_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + @register_model def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt50-32x4d model. 
@@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt-101 model. + """Constructs a ResNeXt-101 32x4d model. """ default_cfg = default_cfgs['resnext101_32x4d'] model = ResNet( @@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + default_cfg = default_cfgs['resnext101_32x8d'] + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt101-64x4d model. @@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. +def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt50-32x4d model with original Torchvision weights. """ - default_cfg = default_cfgs['resnext152_32x4d'] + default_cfg = default_cfgs['tv_resnext50_32x4d'] model = ResNet( - Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4, + Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4, num_classes=num_classes, in_chans=in_chans, **kwargs) model.default_cfg = default_cfg if pretrained: From c3287aafb3b226f291938df2ed6081a21931c329 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 2019 16:17:06 -0700 Subject: [PATCH 4/6] Slight improvement in EfficientNet-B2 native PyTorch weights --- README.md | 2 +- timm/models/gen_efficientnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ac2fbdf6..194a6476 100644 --- a/README.md +++ b/README.md @@ -87,7 +87,7 @@ I've leveraged the training scripts in this repository to train a few of the mod #### @ 260x260 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| -| efficientnet_b2 | 79.668 (20.332) | 94.634 (5.366) | 9.11M | bicubic | +| efficientnet_b2 | 79.760 (20.240) | 94.714 (5.286) | 9.11M | bicubic | ### Ported Weights diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 0642a1cb..9a0a6dd4 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -84,7 +84,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882), 'efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-d4105846.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890), 'efficientnet_b3': _cfg( url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), From 188aeae8f48dc8888f234b87efc390ed373d071b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 
2019 16:17:54 -0700 Subject: [PATCH 5/6] Bump version 0.1.4 --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 10939f01..7525d199 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.1.2' +__version__ = '0.1.4' From 9b0070edc91177f048d2d0ab0717f4c5579427d1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 29 Jun 2019 16:44:25 -0700 Subject: [PATCH 6/6] Add two comments back, fix typo --- timm/models/gen_efficientnet.py | 4 ++-- timm/models/resnet.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py index 9a0a6dd4..2541bc6b 100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -1430,7 +1430,7 @@ def efficientnet_b1(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B1 """ default_cfg = default_cfgs['efficientnet_b1'] # NOTE for train, drop_rate should be 0.2 - kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.0, depth_multiplier=1.1, num_classes=num_classes, in_chans=in_chans, **kwargs) @@ -1445,7 +1445,7 @@ def efficientnet_b2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """ EfficientNet-B2 """ default_cfg = default_cfgs['efficientnet_b2'] # NOTE for train, drop_rate should be 0.3 - kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg + #kwargs['drop_connect_rate'] = 0.2 # set when training, TODO add as cmd arg model = _gen_efficientnet( channel_multiplier=1.1, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 9a4b22cd..32ff3acf 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -334,7 +334,7 @@ def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a Wide ResNet-100-2 model. + """Constructs a Wide ResNet-101-2 model. The model is the same as ResNet except for the bottleneck number of channels which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same.
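Usage note for the resolve_data_config() change in PATCH 1/6: the helper now takes a plain dict (e.g. vars(args)) as its first argument and the model as a keyword argument. A minimal sketch, assuming the import locations used elsewhere in this repo (timm.models.create_model, timm.data.resolve_data_config); the model name and the interpolation override are illustrative only:

```python
from timm.models import create_model
from timm.data import resolve_data_config

model = create_model('resnet50', pretrained=False)

# argparse callers pass vars(args); a plain dict works the same way
config = resolve_data_config({'interpolation': 'bicubic'}, model=model, verbose=False)

# resolved keys: input_size, interpolation, mean, std, crop_pct
print(config)
```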