From 55f7dfa9ea8bab0296c774dd7234d602b8396ce5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 18 Jan 2021 16:11:02 -0800 Subject: [PATCH] Refactor vision_transformer entrpy fns, add pos embedding resize support for fine tuning, add some deit models for testing --- tests/test_models.py | 4 +- timm/models/vision_transformer.py | 307 +++++++++++++++--------------- 2 files changed, 153 insertions(+), 158 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 17d592d4..dee4fbe7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -14,7 +14,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): torch._C._jit_set_profiling_mode(False) # transformer models don't support many of the spatial / feature based model functionalities -NON_STD_FILTERS = ['vit_*', 'deit_*'] +NON_STD_FILTERS = ['vit_*'] # exclude models that cause specific test failures if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): @@ -29,7 +29,7 @@ MAX_FWD_FEAT_SIZE = 448 @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-2])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 076010ab..a832cce3 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -17,11 +17,15 @@ paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.128 Hacked together by / Copyright 2020 Ross Wightman """ -import torch -import torch.nn as nn +import math +import logging from functools import partial from collections import OrderedDict +import torch +import torch.nn as nn +import torch.nn.functional as F + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import load_pretrained from .layers import DropPath, to_2tuple, trunc_normal_ @@ -29,6 +33,8 @@ from .resnet import resnet26d, resnet50d from .resnetv2 import ResNetV2, StdConv2dSame from .registry import register_model +_logger = logging.getLogger(__name__) + def _cfg(url='', **kwargs): return { @@ -94,7 +100,7 @@ default_cfgs = { # hybrid models (weights ported from official Google JAX impl) 'vit_base_resnet50_224_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth', - mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9), + num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9), 'vit_base_resnet50_384': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth', input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0), @@ -106,15 +112,15 @@ default_cfgs = { 'vit_base_resnet50d_224': _cfg(), # deit models (FB weights) - 'deit_tiny_patch16_224': _cfg( + 'vit_deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), - 'deit_small_patch16_224': _cfg( + 'vit_deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), - 'deit_base_patch16_224': _cfg( + 'vit_deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), - 'deit_base_patch16_384': _cfg( - url='', # no weights yet - input_size=(3, 384, 384)), + 'vit_deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + input_size=(3, 384, 384), crop_pct=1.0), } @@ -253,11 +259,12 @@ class VisionTransformer(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) if hybrid_backbone is not None: self.patch_embed = HybridEmbed( @@ -290,7 +297,7 @@ class VisionTransformer(nn.Module): self.pre_logits = nn.Identity() # Classifier head - self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) @@ -338,180 +345,196 @@ class VisionTransformer(nn.Module): return x -def _conv_filter(state_dict, patch_size=16): +def resize_pos_embed(posemb, posemb_new): + # Rescale the grid of position embeddings when loading from state_dict + # Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) + ntok_new = posemb_new.shape[1] + if True: + posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + ntok_new -= 1 + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + state_dict['pos_embed'] = posemb + return state_dict + + +def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" out_dict = {} + if 'model' in state_dict: + # for deit models + state_dict = state_dict['model'] for k, v in state_dict.items(): - if 'patch_embed.proj.weight' in k: - v = v.reshape((v.shape[0], 3, patch_size, patch_size)) + if 'patch_embed.proj.weight' in k and len(v.shape) < 4: + # for old models that I trained prior to conv based patchification + v = v.reshape(model.patch_embed.proj.weight.shape) + elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + # to resize pos embedding when using model at different size from pretrained weights + v = resize_pos_embed(v, model.pos_embed) out_dict[k] = v return out_dict +def _create_vision_transformer(variant, pretrained=False, **kwargs): + default_cfg = default_cfgs[variant] + default_num_classes = default_cfg['num_classes'] + default_img_size = default_cfg['input_size'][-1] + + num_classes = kwargs.pop('num_classes', default_num_classes) + img_size = kwargs.pop('img_size', default_img_size) + repr_size = kwargs.pop('representation_size', None) + if repr_size is not None and num_classes != default_num_classes: + # remove representation layer if fine-tuning + _logger.info("Removing representation layer for fine-tuning.") + repr_size = None + + model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs) + model.default_cfg = default_cfg + + if pretrained: + load_pretrained( + model, num_classes=num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=partial(checkpoint_filter_fn, model=model)) + return model + + @register_model def vit_small_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., + qkv_bias=False, norm_layer=nn.LayerNorm, **kwargs) if pretrained: # NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model - kwargs.setdefault('qk_scale', 768 ** -0.5) - model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs) - model.default_cfg = default_cfgs['vit_small_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + model_kwargs.setdefault('qk_scale', 768 ** -0.5) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_224(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=224, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch32_224'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_384(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch16_384'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_384(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch32_384'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_large_patch16_224'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch32_224(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=224, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_large_patch32_224'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_384(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_large_patch16_384'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch16_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.pop('num_classes', 21843) - model = VisionTransformer( - patch_size=16, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch16_224_in21k'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter) + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_384_in21k(pretrained=False, **kwargs): + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_patch32_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.pop('num_classes', 21843) - model = VisionTransformer( - img_size=224, num_classes=num_classes, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, - qkv_bias=True, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_patch32_224_in21k'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict( + patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_large_patch16_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.pop('num_classes', 21843) - model = VisionTransformer( - patch_size=16, num_classes=num_classes, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True, - representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_large_patch16_224_in21k'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model +# @register_model +# def vit_large_patch16_384_in21k(pretrained=False, **kwargs): +# model_kwargs = dict( +# patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) +# model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) +# return model + + @register_model def vit_large_patch32_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.get('num_classes', 21843) - model = VisionTransformer( - img_size=224, num_classes=num_classes, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, - qkv_bias=True, representation_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_large_patch32_224_in21k'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict( + patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs) + model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): - num_classes = kwargs.pop('num_classes', 21843) - model = VisionTransformer( - img_size=224, patch_size=14, num_classes=num_classes, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, - qkv_bias=True, representation_size=1280, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_huge_patch14_224_in21k'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs) + model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) return model @register_model def vit_base_resnet50_224_in21k(pretrained=False, **kwargs): # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head - num_classes = kwargs.pop('num_classes', 21843) backbone = ResNetV2( layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') - model = VisionTransformer( - img_size=224, num_classes=num_classes, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - hybrid_backbone=backbone, representation_size=768, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_resnet50_224_in21k'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict( + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, + representation_size=768, **kwargs) + model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -520,12 +543,8 @@ def vit_base_resnet50_384(pretrained=False, **kwargs): # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head backbone = ResNetV2( layers=(3, 4, 9), preact=False, stem_type='same', conv_layer=StdConv2dSame, num_classes=0, global_pool='') - model = VisionTransformer( - img_size=384, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, - qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['vit_base_resnet50_384'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs) return model @@ -533,9 +552,8 @@ def vit_base_resnet50_384(pretrained=False, **kwargs): def vit_small_resnet26d_224(pretrained=False, **kwargs): pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) - model = VisionTransformer( - img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model.default_cfg = default_cfgs['vit_small_resnet26d_224'] + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs) return model @@ -543,9 +561,8 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[3]) - model = VisionTransformer( - img_size=224, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model.default_cfg = default_cfgs['vit_small_resnet50d_s3_224'] + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs) return model @@ -553,9 +570,8 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs): def vit_base_resnet26d_224(pretrained=False, **kwargs): pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing backbone = resnet26d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) - model = VisionTransformer( - img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) - model.default_cfg = default_cfgs['vit_base_resnet26d_224'] + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs) return model @@ -563,55 +579,34 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs): def vit_base_resnet50d_224(pretrained=False, **kwargs): pretrained_backbone = kwargs.get('pretrained_backbone', True) # default to True for now, for testing backbone = resnet50d(pretrained=pretrained_backbone, features_only=True, out_indices=[4]) - model = VisionTransformer( - img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) - model.default_cfg = default_cfgs['vit_base_resnet50d_224'] + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs) + model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs) return model @register_model -def deit_tiny_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer( - patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['deit_tiny_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) +def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def deit_small_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer( - patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['deit_small_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) +def vit_deit_small_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def deit_base_patch16_224(pretrained=False, **kwargs): - model = VisionTransformer( - patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['deit_base_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) +def vit_deit_base_patch16_224(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @register_model -def deit_base_patch16_384(pretrained=False, **kwargs): - model = VisionTransformer( - img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True, - norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) - model.default_cfg = default_cfgs['deit_base_patch16_384'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), filter_fn=lambda x: x['model']) +def vit_deit_base_patch16_384(pretrained=False, **kwargs): + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs) + model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs) return model