From f591e90b0d4896561143ba47fe62b92f3804ad68 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 15:33:47 -0700 Subject: [PATCH 1/6] Make sure num_features attr is present in vit models as with others --- timm/models/vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 042efc05..32acccf3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -206,7 +206,7 @@ class VisionTransformer(nn.Module): drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): super().__init__() self.num_classes = num_classes - self.embed_dim = embed_dim + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models if hybrid_backbone is not None: self.patch_embed = HybridEmbed( From da6cd2cc1fd8696986b1cf224a464f45819eec2d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 15:43:39 -0700 Subject: [PATCH 2/6] Fix regression for pretrained classifier loading when using entrypt functions directly --- tests/test_models.py | 2 +- timm/models/helpers.py | 5 +++-- timm/models/hrnet.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index c673dc96..db8efbf3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ MAX_FWD_FEAT_SIZE = 448 @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/helpers.py b/timm/models/helpers.py index ac119295..b90ce1db 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -277,11 +277,12 @@ def build_model_with_cfg( if pruned: model = adapt_model_from_file(model, variant) + # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: load_pretrained( model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), + num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 1e867686..2e8757b5 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs): strict = True if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures + model_kwargs['num_classes'] = 0 strict = False return build_model_with_cfg( From e90edce438b3751cb5ed93c00e65516b00885040 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 15:45:17 -0700 Subject: [PATCH 3/6] Support native silu activation (aka swish). An optimized ver is available in PyTorch 1.7. --- timm/models/layers/create_act.py | 22 ++++++++++++++++------ 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 5bc4db99..6f2ab83e 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -6,9 +6,14 @@ from .activations_jit import * from .activations_me import * from .config import is_exportable, is_scriptable, is_no_jit +# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code +# will use native version if present. Eventually, the custom Swish layers will be removed +# and only native 'silu' will be used. +_has_silu = 'silu' in dir(torch.nn.functional) _ACT_FN_DEFAULT = dict( - swish=swish, + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, mish=mish, relu=F.relu, relu6=F.relu6, @@ -26,7 +31,8 @@ _ACT_FN_DEFAULT = dict( ) _ACT_FN_JIT = dict( - swish=swish_jit, + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, mish=mish_jit, hard_sigmoid=hard_sigmoid_jit, hard_swish=hard_swish_jit, @@ -34,7 +40,8 @@ _ACT_FN_JIT = dict( ) _ACT_FN_ME = dict( - swish=swish_me, + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, mish=mish_me, hard_sigmoid=hard_sigmoid_me, hard_swish=hard_swish_me, @@ -42,7 +49,8 @@ _ACT_FN_ME = dict( ) _ACT_LAYER_DEFAULT = dict( - swish=Swish, + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, mish=Mish, relu=nn.ReLU, relu6=nn.ReLU6, @@ -60,7 +68,8 @@ _ACT_LAYER_DEFAULT = dict( ) _ACT_LAYER_JIT = dict( - swish=SwishJit, + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, mish=MishJit, hard_sigmoid=HardSigmoidJit, hard_swish=HardSwishJit, @@ -68,7 +77,8 @@ _ACT_LAYER_JIT = dict( ) _ACT_LAYER_ME = dict( - swish=SwishMe, + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, mish=MishMe, hard_sigmoid=HardSigmoidMe, hard_swish=HardSwishMe, From 61200db0abcdfee0b609a4e75b20258a5903827b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 15:49:36 -0700 Subject: [PATCH 4/6] in_chans=1 working w/ pretrained weights for vision_transformer --- timm/models/vision_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 32acccf3..bd9ea231 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -37,7 +37,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': '', 'classifier': 'head', + 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } From b401952caf92340d4f12fd62a947c147e435357e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 17:31:01 -0700 Subject: [PATCH 5/6] Add newly added vision transformer large/base 224x224 weights ported from JAX official repo --- timm/models/vision_transformer.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bd9ea231..72f3a61a 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -48,7 +48,8 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth', ), 'vit_base_patch16_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth', + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), ), 'vit_base_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth', @@ -56,7 +57,9 @@ default_cfgs = { 'vit_base_patch32_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), - 'vit_large_patch16_224': _cfg(), + 'vit_large_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_224-4ee7a4dc.pth', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), 'vit_large_patch16_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), @@ -305,10 +308,9 @@ def vit_small_patch16_224(pretrained=False, **kwargs): @register_model def vit_base_patch16_224(pretrained=False, **kwargs): - if pretrained: - # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - kwargs.setdefault('qk_scale', 768 ** -0.5) - model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = VisionTransformer( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_base_patch16_224'] if pretrained: load_pretrained( @@ -340,8 +342,12 @@ def vit_base_patch32_384(pretrained=False, **kwargs): @register_model def vit_large_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = VisionTransformer( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model.default_cfg = default_cfgs['vit_large_patch16_224'] + if pretrained: + load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) return model From 741572dc9d7eed93b550300c38cdc880c312de66 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 17:31:39 -0700 Subject: [PATCH 6/6] Bump version to 0.3.0 for pending PyPi push --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index 020ed73d..0404d810 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.2.2' +__version__ = '0.3.0'