diff --git a/README.md b/README.md index 64d8a6f1..194a6476 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,7 @@ I've leveraged the training scripts in this repository to train a few of the mod |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| | resnext50_32x4d | 78.512 (21.488) | 94.042 (5.958) | 25M | bicubic | +| resnet50 | 78.470 (21.530) | 94.266 (5.734) | 25.6M | bicubic | | seresnext26_32x4d | 77.104 (22.896) | 93.316 (6.684) | 16.8M | bicubic | | efficientnet_b0 | 76.912 (23.088) | 93.210 (6.790) | 5.29M | bicubic | | mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | @@ -86,7 +87,7 @@ I've leveraged the training scripts in this repository to train a few of the mod #### @ 260x260 |Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | |---|---|---|---|---| -| efficientnet_b2 | 79.668 (20.332) | 94.634 (5.366) | 9.11M | bicubic | +| efficientnet_b2 | 79.760 (20.240) | 94.714 (5.286) | 9.11M | bicubic | ### Ported Weights diff --git a/inference.py b/inference.py index 9077cc07..3255a8d9 100644 --- a/inference.py +++ b/inference.py @@ -70,7 +70,7 @@ def main(): logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - config = resolve_data_config(model, args) + config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, config, args) if args.num_gpu > 1: diff --git a/timm/data/config.py b/timm/data/config.py index 1675d2a9..8a83d19f 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -2,35 +2,43 @@ import logging from .constants import * -def resolve_data_config(model, args, default_cfg={}, verbose=True): +def resolve_data_config(args, default_cfg={}, model=None, verbose=True): new_config = {} default_cfg = default_cfg - if not default_cfg and hasattr(model, 'default_cfg'): + if not default_cfg and model is not None and hasattr(model, 'default_cfg'): default_cfg = model.default_cfg # Resolve input/image size - # FIXME grayscale/chans arg to use different # channels? 
 in_chans = 3
+    if 'chans' in args and args['chans'] is not None:
+        in_chans = args['chans']
+
     input_size = (in_chans, 224, 224)
-    if args.img_size is not None:
-        # FIXME support passing img_size as tuple, non-square
-        assert isinstance(args.img_size, int)
-        input_size = (in_chans, args.img_size, args.img_size)
+    if 'input_size' in args and args['input_size'] is not None:
+        assert isinstance(args['input_size'], (tuple, list))
+        assert len(args['input_size']) == 3
+        input_size = tuple(args['input_size'])
+        in_chans = input_size[0]  # input_size overrides in_chans
+    elif 'img_size' in args and args['img_size'] is not None:
+        assert isinstance(args['img_size'], int)
+        input_size = (in_chans, args['img_size'], args['img_size'])
     elif 'input_size' in default_cfg:
         input_size = default_cfg['input_size']
     new_config['input_size'] = input_size
 
     # resolve interpolation method
-    new_config['interpolation'] = 'bilinear'
-    if args.interpolation:
-        new_config['interpolation'] = args.interpolation
+    new_config['interpolation'] = 'bicubic'
+    if 'interpolation' in args and args['interpolation']:
+        new_config['interpolation'] = args['interpolation']
     elif 'interpolation' in default_cfg:
         new_config['interpolation'] = default_cfg['interpolation']
 
     # resolve dataset + model mean for normalization
-    new_config['mean'] = get_mean_by_model(args.model)
-    if args.mean is not None:
-        mean = tuple(args.mean)
+    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    if 'model' in args:
+        new_config['mean'] = get_mean_by_model(args['model'])
+    if 'mean' in args and args['mean'] is not None:
+        mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
@@ -40,9 +48,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
         new_config['mean'] = default_cfg['mean']
 
     # resolve dataset + model std deviation for normalization
-    new_config['std'] = get_std_by_model(args.model)
-    if args.std is not None:
-        std = tuple(args.std)
+    new_config['std'] = IMAGENET_DEFAULT_STD
+    if 'model' in args:
+        new_config['std'] = get_std_by_model(args['model'])
+    if 'std' in args and args['std'] is not None:
+        std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
@@ -53,7 +63,9 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
 
     # resolve default crop percentage
     new_config['crop_pct'] = DEFAULT_CROP_PCT
-    if 'crop_pct' in default_cfg:
+    if 'crop_pct' in args and args['crop_pct'] is not None:
+        new_config['crop_pct'] = args['crop_pct']
+    elif 'crop_pct' in default_cfg:
         new_config['crop_pct'] = default_cfg['crop_pct']
 
     if verbose:
@@ -64,29 +76,11 @@ def resolve_data_config(model, args, default_cfg={}, verbose=True):
     return new_config
 
 
-def get_mean_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_MEAN
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_MEAN
-    else:
-        return IMAGENET_DEFAULT_MEAN
-
-
-def get_std_by_name(name):
-    if name == 'dpn':
-        return IMAGENET_DPN_STD
-    elif name == 'inception' or name == 'le':
-        return IMAGENET_INCEPTION_STD
-    else:
-        return IMAGENET_DEFAULT_STD
-
-
 def get_mean_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
-        return IMAGENET_DPN_MEAN
+        return IMAGENET_DPN_MEAN
     elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_MEAN
     else:
         return IMAGENET_DEFAULT_MEAN
@@ -96,7 +90,7 @@ def get_std_by_model(model_name):
     model_name = model_name.lower()
     if 'dpn' in model_name:
         return IMAGENET_DPN_STD
-    
elif 'ception' in model_name or 'nasnet' in model_name:
+    elif 'ception' in model_name or ('nasnet' in model_name and 'mnasnet' not in model_name):
         return IMAGENET_INCEPTION_STD
     else:
         return IMAGENET_DEFAULT_STD
diff --git a/timm/data/loader.py b/timm/data/loader.py
index 777eb878..6a19b805 100644
--- a/timm/data/loader.py
+++ b/timm/data/loader.py
@@ -86,6 +86,7 @@ def create_loader(
         use_prefetcher=True,
         rand_erase_prob=0.,
         rand_erase_mode='const',
+        color_jitter=0.4,
         interpolation='bilinear',
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD,
@@ -107,6 +108,7 @@ def create_loader(
     if is_training:
         transform = transforms_imagenet_train(
             img_size,
+            color_jitter=color_jitter,
             interpolation=interpolation,
             use_prefetcher=use_prefetcher,
             mean=mean,
diff --git a/timm/data/random_erasing.py b/timm/data/random_erasing.py
index c16725ae..e66f7b95 100644
--- a/timm/data/random_erasing.py
+++ b/timm/data/random_erasing.py
@@ -6,12 +6,13 @@ import torch
 def _get_pixels(per_pixel, rand_color, patch_size, dtype=torch.float32, device='cuda'):
     # NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
     # paths, flip the order so normal is run on CPU if this becomes a problem
-    # ie torch.empty(patch_size, dtype=dtype).normal_().to(device=device)
+    # The issue has been fixed in PyTorch master, https://github.com/pytorch/pytorch/issues/19508
+    # will revert to doing normal_() on GPU when the fix is in the next release
     if per_pixel:
         return torch.empty(
-            patch_size, dtype=dtype, device=device).normal_()
+            patch_size, dtype=dtype).normal_().to(device=device)
     elif rand_color:
-        return torch.empty((patch_size[0], 1, 1), dtype=dtype, device=device).normal_()
+        return torch.empty((patch_size[0], 1, 1), dtype=dtype).normal_().to(device=device)
     else:
         return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)
diff --git a/timm/data/transforms.py b/timm/data/transforms.py
index bee505a2..13a6ff01 100644
--- a/timm/data/transforms.py
+++ b/timm/data/transforms.py
@@ -156,7 +156,7 @@ class RandomResizedCropAndInterpolation(object):
 def transforms_imagenet_train(
         img_size=224,
         scale=(0.08, 1.0),
-        color_jitter=(0.4, 0.4, 0.4),
+        color_jitter=0.4,
         interpolation='random',
         random_erasing=0.4,
         random_erasing_mode='const',
@@ -164,6 +164,13 @@ def transforms_imagenet_train(
         mean=IMAGENET_DEFAULT_MEAN,
         std=IMAGENET_DEFAULT_STD
 ):
+    if isinstance(color_jitter, (list, tuple)):
+        # color jitter should be a 3-tuple/list if specifying brightness/contrast/saturation
+        # or 4 if also augmenting hue
+        assert len(color_jitter) in (3, 4)
+    else:
+        # if it's a scalar, duplicate for brightness, contrast, and saturation, no hue
+        color_jitter = (float(color_jitter),) * 3
 
     tfl = [
         RandomResizedCropAndInterpolation(
diff --git a/timm/models/dpn.py b/timm/models/dpn.py
index 76b59ca2..92bc7855 100644
--- a/timm/models/dpn.py
+++ b/timm/models/dpn.py
@@ -35,9 +35,9 @@ def _cfg(url=''):
 default_cfgs = {
     'dpn68': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
-    'dpn68b_extra': _cfg(
+    'dpn68b': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68b_extra-84854c156.pth'),
-    'dpn92_extra': _cfg(
+    'dpn92': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn92_extra-b040e4a9b.pth'),
     'dpn98': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn98-5b90dec4d.pth'),
diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
index 1f5890bc..2541bc6b 
100644 --- a/timm/models/gen_efficientnet.py +++ b/timm/models/gen_efficientnet.py @@ -84,7 +84,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', input_size=(3, 240, 240), pool_size=(8, 8), interpolation='bicubic', crop_pct=0.882), 'efficientnet_b2': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-d4105846.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2-cf78dc4d.pth', input_size=(3, 260, 260), pool_size=(9, 9), interpolation='bicubic', crop_pct=0.890), 'efficientnet_b3': _cfg( url='', input_size=(3, 300, 300), pool_size=(10, 10), crop_pct=0.904), diff --git a/timm/models/gluon_resnet.py b/timm/models/gluon_resnet.py index c5d0634f..715e0950 100644 --- a/timm/models/gluon_resnet.py +++ b/timm/models/gluon_resnet.py @@ -50,11 +50,9 @@ default_cfgs = { 'gluon_resnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext50_32x4d-e6a097c1.pth'), 'gluon_resnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_32x4d-b253c8c4.pth'), 'gluon_resnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_resnext101_64x4d-f9a8e184.pth'), - 'gluon_resnext152_32x4d': _cfg(url=''), 'gluon_seresnext50_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext50_32x4d-90cf2d6e.pth'), 'gluon_seresnext101_32x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_32x4d-cf52900d.pth'), 'gluon_seresnext101_64x4d': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_seresnext101_64x4d-f9926f93.pth'), - 'gluon_seresnext152_32x4d': _cfg(url=''), 'gluon_senet154': _cfg(url='https://github.com/rwightman/pytorch-pretrained-gluonresnet/releases/download/v0.1/gluon_senet154-70a1a3c0.pth'), } @@ -617,20 +615,6 @@ def gluon_resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwa return model -@register_model -def gluon_resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. - """ - default_cfg = default_cfgs['gluon_resnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - if pretrained: - load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_seresnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a SEResNeXt50-32x4d model. @@ -673,20 +657,6 @@ def gluon_seresnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **k return model -@register_model -def gluon_seresnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a SEResNeXt152-32x4d model. 
- """ - default_cfg = default_cfgs['gluon_seresnext152_32x4d'] - model = GluonResNet( - BottleneckGl, [3, 8, 36, 3], cardinality=32, base_width=4, use_se=True, - num_classes=num_classes, in_chans=in_chans, **kwargs) - model.default_cfg = default_cfg - #if pretrained: - # load_pretrained(model, default_cfg, num_classes, in_chans) - return model - - @register_model def gluon_senet154(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs an SENet-154 model. diff --git a/timm/models/helpers.py b/timm/models/helpers.py index b7de304a..1deff273 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -28,8 +28,9 @@ def load_checkpoint(model, checkpoint_path, use_ema=False): raise FileNotFoundError() -def resume_checkpoint(model, checkpoint_path, start_epoch=None): +def resume_checkpoint(model, checkpoint_path): optimizer_state = None + resume_epoch = None if os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path) if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: @@ -40,13 +41,15 @@ def resume_checkpoint(model, checkpoint_path, start_epoch=None): model.load_state_dict(new_state_dict) if 'optimizer' in checkpoint: optimizer_state = checkpoint['optimizer'] - start_epoch = checkpoint['epoch'] if start_epoch is None else start_epoch + if 'epoch' in checkpoint: + resume_epoch = checkpoint['epoch'] + if 'version' in checkpoint and checkpoint['version'] > 1: + resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save logging.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch'])) else: model.load_state_dict(checkpoint) - start_epoch = 0 if start_epoch is None else start_epoch logging.info("Loaded checkpoint '{}'".format(checkpoint_path)) - return optimizer_state, start_epoch + return optimizer_state, resume_epoch else: logging.error("No checkpoint found at '{}'".format(checkpoint_path)) raise FileNotFoundError() diff --git a/timm/models/registry.py b/timm/models/registry.py index 45bc1809..c15f5414 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -5,22 +5,36 @@ from collections import defaultdict __all__ = ['list_models', 'is_model', 'model_entrypoint', 'list_modules', 'is_model_in_modules'] -_module_to_models = defaultdict(set) -_model_to_module = {} -_model_entrypoints = {} +_module_to_models = defaultdict(set) # dict of sets to check membership of model in module +_model_to_module = {} # mapping of model names to module names +_model_entrypoints = {} # mapping of model names to entrypoint fns +_model_has_pretrained = set() # set of model names that have pretrained weight url present def register_model(fn): + # lookup containing module mod = sys.modules[fn.__module__] module_name_split = fn.__module__.split('.') module_name = module_name_split[-1] if len(module_name_split) else '' + + # add model to __all__ in module + model_name = fn.__name__ if hasattr(mod, '__all__'): - mod.__all__.append(fn.__name__) + mod.__all__.append(model_name) else: - mod.__all__ = [fn.__name__] - _model_entrypoints[fn.__name__] = fn - _model_to_module[fn.__name__] = module_name - _module_to_models[module_name].add(fn.__name__) + mod.__all__ = [model_name] + + # add entries to registry dict/sets + _model_entrypoints[model_name] = fn + _model_to_module[model_name] = module_name + _module_to_models[module_name].add(model_name) + has_pretrained = False # check if model has a pretrained url to allow filtering on this + if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs: + # 
this will catch all models that have entrypoint matching cfg key, but miss any aliasing + # entrypoints or non-matching combos + has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url'] + if has_pretrained: + _model_has_pretrained.add(model_name) return fn @@ -28,7 +42,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module=''): +def list_models(filter='', module='', pretrained=False): """ Return list of available model names, sorted alphabetically Args: @@ -45,6 +59,8 @@ def list_models(filter='', module=''): models = _model_entrypoints.keys() if filter: models = fnmatch.filter(models, filter) + if pretrained: + models = _model_has_pretrained.intersection(models) return list(sorted(models, key=_natural_key)) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7ed3b2e1..32ff3acf 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -33,14 +33,22 @@ default_cfgs = { 'resnet18': _cfg(url='https://download.pytorch.org/models/resnet18-5c106cde.pth'), 'resnet34': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth'), - 'resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'resnet50': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/rw_resnet50-86acaeed.pth', + interpolation='bicubic'), 'resnet101': _cfg(url='https://download.pytorch.org/models/resnet101-5d3b4d8f.pth'), 'resnet152': _cfg(url='https://download.pytorch.org/models/resnet152-b121ed2d.pth'), - 'resnext50_32x4d': _cfg(url='https://www.dropbox.com/s/yxci33lfew51p6a/resnext50_32x4d-068914d1.pth?dl=1', - interpolation='bicubic'), + 'tv_resnet34': _cfg(url='https://download.pytorch.org/models/resnet34-333f7ec4.pth'), + 'tv_resnet50': _cfg(url='https://download.pytorch.org/models/resnet50-19c8e357.pth'), + 'wide_resnet50_2': _cfg(url='https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth'), + 'wide_resnet101_2': _cfg(url='https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth'), + 'resnext50_32x4d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnext50_32x4d-068914d1.pth', + interpolation='bicubic'), 'resnext101_32x4d': _cfg(url=''), + 'resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth'), 'resnext101_64x4d': _cfg(url=''), - 'resnext152_32x4d': _cfg(url=''), + 'tv_resnext50_32x4d': _cfg(url='https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth'), 'ig_resnext101_32x8d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth'), 'ig_resnext101_32x16d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth'), 'ig_resnext101_32x32d': _cfg(url='https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth'), @@ -285,6 +293,61 @@ def resnet152(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def tv_resnet34(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-34 model with original Torchvision weights. 
+ """ + model = ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet34'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def tv_resnet50(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNet-50 model with original Torchvision weights. + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['tv_resnet50'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet50_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-50-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + """ + model = ResNet( + Bottleneck, [3, 4, 6, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet50_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + +@register_model +def wide_resnet101_2(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a Wide ResNet-101-2 model. + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same. + """ + model = ResNet( + Bottleneck, [3, 4, 23, 3], base_width=128, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfgs['wide_resnet101_2'] + if pretrained: + load_pretrained(model, model.default_cfg, num_classes, in_chans) + return model + + @register_model def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt50-32x4d model. @@ -301,7 +364,7 @@ def resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt-101 model. + """Constructs a ResNeXt-101 32x4d model. """ default_cfg = default_cfgs['resnext101_32x4d'] model = ResNet( @@ -313,6 +376,20 @@ def resnext101_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): return model +@register_model +def resnext101_32x8d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a ResNeXt-101 32x8d model. + """ + default_cfg = default_cfgs['resnext101_32x8d'] + model = ResNet( + Bottleneck, [3, 4, 23, 3], cardinality=32, base_width=8, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model + + @register_model def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): """Constructs a ResNeXt101-64x4d model. @@ -328,12 +405,12 @@ def resnext101_64x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): @register_model -def resnext152_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): - """Constructs a ResNeXt152-32x4d model. 
+def tv_resnext50_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
     """
-    default_cfg = default_cfgs['resnext152_32x4d']
+    default_cfg = default_cfgs['tv_resnext50_32x4d']
     model = ResNet(
-        Bottleneck, [3, 8, 36, 3], cardinality=32, base_width=4,
+        Bottleneck, [3, 4, 6, 3], cardinality=32, base_width=4,
         num_classes=num_classes, in_chans=in_chans, **kwargs)
     model.default_cfg = default_cfg
     if pretrained:
diff --git a/timm/scheduler/scheduler.py b/timm/scheduler/scheduler.py
index 59fcfc16..78e8460d 100644
--- a/timm/scheduler/scheduler.py
+++ b/timm/scheduler/scheduler.py
@@ -56,7 +56,7 @@ class Scheduler:
 
     def step(self, epoch: int, metric: float = None) -> None:
         self.metric = metric
-        values = self.get_epoch_values(epoch + 1)  # +1 to calculate for next epoch
+        values = self.get_epoch_values(epoch)
         if values is not None:
             self.update_groups(values)
diff --git a/timm/utils.py b/timm/utils.py
index 8d4418a6..36355c2b 100644
--- a/timm/utils.py
+++ b/timm/utils.py
@@ -83,7 +83,8 @@ class CheckpointSaver:
             'arch': args.model,
             'state_dict': get_state_dict(model),
             'optimizer': optimizer.state_dict(),
-            'args': args
+            'args': args,
+            'version': 2,  # version < 2 increments epoch before save
         }
         if model_ema is not None:
             save_state['state_dict_ema'] = get_state_dict(model_ema)
diff --git a/timm/version.py b/timm/version.py
index 10939f01..7525d199 100644
--- a/timm/version.py
+++ b/timm/version.py
@@ -1 +1 @@
-__version__ = '0.1.2'
+__version__ = '0.1.4'
diff --git a/train.py b/train.py
index 7196305f..f7ecdd5d 100644
--- a/train.py
+++ b/train.py
@@ -27,22 +27,21 @@ import torchvision.utils
 torch.backends.cudnn.benchmark = True
 
 parser = argparse.ArgumentParser(description='Training')
+# Dataset / Model parameters
 parser.add_argument('data', metavar='DIR',
                     help='path to dataset')
 parser.add_argument('--model', default='resnet101', type=str, metavar='MODEL',
                     help='Name of model to train (default: "resnet101")')
+parser.add_argument('--pretrained', action='store_true', default=False,
+                    help='Start with pretrained version of specified network (if avail)')
+parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
+                    help='Initialize model from this checkpoint (default: none)')
+parser.add_argument('--resume', default='', type=str, metavar='PATH',
+                    help='Resume full model and optimizer state from checkpoint (default: none)')
 parser.add_argument('--num-classes', type=int, default=1000, metavar='N',
                     help='number of label classes (default: 1000)')
-parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
-                    help='Optimizer (default: "sgd"')
-parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
-                    help='Optimizer Epsilon (default: 1e-8)')
 parser.add_argument('--gp', default='avg', type=str, metavar='POOL',
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
-parser.add_argument('--tta', type=int, default=0, metavar='N',
-                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)')
-parser.add_argument('--pretrained', action='store_true', default=False,
-                    help='Start with pretrained version of specified network (if avail)')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
@@ -53,8 +52,24 @@ parser.add_argument('--interpolation', default='', type=str, metavar='NAME',
                     help='Image resize interpolation type (overrides model)')
 parser.add_argument('-b', '--batch-size', type=int, default=32, metavar='N',
                     help='input batch size for training (default: 32)')
-parser.add_argument('-s', '--initial-batch-size', type=int, default=0, metavar='N',
-                    help='initial input batch size for training (default: 0)')
+parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
+                    help='Dropout rate (default: 0.)')
+# Optimizer parameters
+parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
+                    help='Optimizer (default: "sgd")')
+parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
+                    help='Optimizer Epsilon (default: 1e-8)')
+parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
+                    help='SGD momentum (default: 0.9)')
+parser.add_argument('--weight-decay', type=float, default=0.0001,
+                    help='weight decay (default: 0.0001)')
+# Learning rate schedule parameters
+parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
+                    help='LR scheduler (default: "step")')
+parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
+                    help='learning rate (default: 0.01)')
+parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
+                    help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--epochs', type=int, default=200, metavar='N',
                     help='number of epochs to train (default: 200)')
 parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
@@ -65,40 +80,34 @@ parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N',
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
-parser.add_argument('--sched', default='step', type=str, metavar='SCHEDULER',
-                    help='LR scheduler (default: "step"')
-parser.add_argument('--drop', type=float, default=0.0, metavar='DROP',
-                    help='Dropout rate (default: 0.)')
+# Augmentation parameters
+parser.add_argument('--color_jitter', type=float, default=0.4, metavar='PCT',
+                    help='Color jitter factor (default: 0.4)')
 parser.add_argument('--reprob', type=float, default=0., metavar='PCT',
                     help='Random erase prob (default: 0.)')
 parser.add_argument('--remode', type=str, default='const',
                     help='Random erase mode (default: "const")')
-parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
-                    help='learning rate (default: 0.01)')
-parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
-                    help='warmup learning rate (default: 0.0001)')
-parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
-                    help='SGD momentum (default: 0.9)')
-parser.add_argument('--weight-decay', type=float, default=0.0001,
-                    help='weight decay (default: 0.0001)')
 parser.add_argument('--mixup', type=float, default=0.0,
                     help='mixup alpha, mixup enabled if > 0. 
(default: 0.)')
 parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
                     help='turn off mixup after this epoch, disabled if 0 (default: 0)')
 parser.add_argument('--smoothing', type=float, default=0.1,
                     help='label smoothing (default: 0.1)')
+# Batch norm parameters (only works with gen_efficientnet based models currently)
 parser.add_argument('--bn-tf', action='store_true', default=False,
                     help='Use Tensorflow BatchNorm defaults for models that support it (default: False)')
 parser.add_argument('--bn-momentum', type=float, default=None,
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+# Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
 parser.add_argument('--model-ema-force-cpu', action='store_true', default=False,
                     help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
 parser.add_argument('--model-ema-decay', type=float, default=0.9998,
                     help='decay factor for model weights moving average (default: 0.9998)')
+# Misc
 parser.add_argument('--seed', type=int, default=42, metavar='S',
                     help='random seed (default: 42)')
 parser.add_argument('--log-interval', type=int, default=50, metavar='N',
@@ -109,10 +118,6 @@ parser.add_argument('-j', '--workers', type=int, default=4, metavar='N',
                     help='how many training processes to use (default: 4)')
 parser.add_argument('--num-gpu', type=int, default=1,
                     help='Number of GPUS to use')
-parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
-                    help='path to init checkpoint (default: none)')
-parser.add_argument('--resume', default='', type=str, metavar='PATH',
-                    help='path to latest checkpoint (default: none)')
 parser.add_argument('--save-images', action='store_true', default=False,
                     help='save images of input batches every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
@@ -125,6 +130,8 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='path to output folder (default: none, current dir)')
 parser.add_argument('--eval-metric', default='prec1', type=str, metavar='EVAL_METRIC',
                     help='Best metric (default: "prec1")')
+parser.add_argument('--tta', type=int, default=0, metavar='N',
+                    help='Test/inference time augmentation (oversampling) factor. 
0=None (default: 0)') parser.add_argument("--local_rank", default=0, type=int) @@ -174,13 +181,13 @@ def main(): logging.info('Model %s created, param count: %d' % (args.model, sum([m.numel() for m in model.parameters()]))) - data_config = resolve_data_config(model, args, verbose=args.local_rank == 0) + data_config = resolve_data_config(vars(args), model=model, verbose=args.local_rank == 0) # optionally resume from a checkpoint - start_epoch = 0 optimizer_state = None + resume_epoch = None if args.resume: - optimizer_state, start_epoch = resume_checkpoint(model, args.resume, args.start_epoch) + optimizer_state, resume_epoch = resume_checkpoint(model, args.resume) if args.num_gpu > 1: if args.amp: @@ -232,8 +239,15 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP lr_scheduler, num_epochs = create_scheduler(args, optimizer) + start_epoch = 0 + if args.start_epoch is not None: + # a specified start_epoch will always override the resume epoch + start_epoch = args.start_epoch + elif resume_epoch is not None: + start_epoch = resume_epoch if start_epoch > 0: lr_scheduler.step(start_epoch) + if args.local_rank == 0: logging.info('Scheduled epochs: {}'.format(num_epochs)) @@ -255,6 +269,7 @@ def main(): use_prefetcher=args.prefetcher, rand_erase_prob=args.reprob, rand_erase_mode=args.remode, + color_jitter=args.color_jitter, interpolation='random', # FIXME cleanly resolve this? data_config['interpolation'], mean=data_config['mean'], std=data_config['std'], @@ -327,7 +342,8 @@ def main(): eval_metrics = ema_eval_metrics if lr_scheduler is not None: - lr_scheduler.step(epoch, eval_metrics[eval_metric]) + # step LR for next epoch + lr_scheduler.step(epoch + 1, eval_metrics[eval_metric]) update_summary( epoch, train_metrics, eval_metrics, os.path.join(output_dir, 'summary.csv'), @@ -338,9 +354,7 @@ def main(): save_metric = eval_metrics[eval_metric] best_metric, best_epoch = saver.save_checkpoint( model, optimizer, args, - epoch=epoch + 1, - model_ema=model_ema, - metric=save_metric) + epoch=epoch, model_ema=model_ema, metric=save_metric) except KeyboardInterrupt: pass @@ -433,9 +447,8 @@ def train_epoch( if saver is not None and args.recovery_interval and ( last_batch or (batch_idx + 1) % args.recovery_interval == 0): - save_epoch = epoch + 1 if last_batch else epoch saver.save_recovery( - model, optimizer, args, save_epoch, model_ema=model_ema, batch_idx=batch_idx) + model, optimizer, args, epoch, model_ema=model_ema, batch_idx=batch_idx) if lr_scheduler is not None: lr_scheduler.step_update(num_updates=num_updates, metric=losses_m.avg) diff --git a/validate.py b/validate.py index 199888ad..280bc260 100644 --- a/validate.py +++ b/validate.py @@ -71,7 +71,7 @@ def validate(args): param_count = sum([m.numel() for m in model.parameters()]) logging.info('Model %s created, param count: %d' % (args.model, param_count)) - data_config = resolve_data_config(model, args) + data_config = resolve_data_config(vars(args), model=model) model, test_time_pool = apply_test_time_pool(model, data_config, args) if args.num_gpu > 1:
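
Note on the interface change threaded through inference.py, train.py, and validate.py above: resolve_data_config now takes a plain dict of arguments as its first parameter and the model only as an optional keyword, so keys that are absent or None simply fall through to the model's default_cfg. A minimal sketch of the new call pattern follows; the argument names mirror the train.py/validate.py parsers in this diff, and the package-level create_model/list_models exports are assumed to be available as in this release.

import argparse

from timm import create_model, list_models
from timm.data import resolve_data_config

# pretrained=True now filters the registry to models with a pretrained weight URL
print(list_models('resnext*', pretrained=True))

# build an args namespace the way train.py/validate.py would
parser = argparse.ArgumentParser()
parser.add_argument('--model', default='resnet50')
parser.add_argument('--img-size', type=int, default=None)
parser.add_argument('--mean', type=float, nargs='+', default=None)
parser.add_argument('--std', type=float, nargs='+', default=None)
parser.add_argument('--interpolation', default='', type=str)
args = parser.parse_args([])

model = create_model(args.model, pretrained=False)

# dict first, model as an optional keyword; a partial argument dict is fine
# because missing keys ('input_size', 'chans', 'crop_pct', ...) fall back to
# the model's default_cfg
config = resolve_data_config(vars(args), model=model)
print(config)
# -> {'input_size': (3, 224, 224), 'interpolation': 'bicubic',
#     'mean': (0.485, 0.456, 0.406), 'std': (0.229, 0.224, 0.225),
#     'crop_pct': 0.875}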