From ea9c9550b24dfaf30fdcca960b9cc24a65c359fe Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 1 Apr 2021 14:17:38 -0700 Subject: [PATCH] Fully move ViT hybrids to their own file, including embedding module. Remove some extra DeiT models that were for benchmarking only. --- timm/models/vision_transformer.py | 134 +----------- timm/models/vision_transformer_hybrid.py | 256 +++++++++-------------- 2 files changed, 112 insertions(+), 278 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 5f244589..578a5f08 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -5,6 +5,9 @@ A PyTorch implement of Vision Transformers as described in The official jax code is released and available at https://github.com/google-research/vision_transformer +DeiT model defs and weights from https://github.com/facebookresearch/deit, +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out @@ -12,9 +15,6 @@ for some einops/einsum fun * Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT * Bert reference code checks against Huggingface Transformers and Tensorflow Bert -DeiT model defs and weights from https://github.com/facebookresearch/deit, -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 - Hacked together by / Copyright 2020 Ross Wightman """ import math @@ -99,18 +99,8 @@ default_cfgs = { # deit models (FB weights) 'vit_deit_tiny_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), - 'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843), - 'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)), - 'vit_deit_small_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), - 'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843), - 'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)), - - 'vit_deit_small_patch32_224': _cfg(), - 'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843), - 'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)), - 'vit_deit_base_patch16_224': _cfg( url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',), 'vit_deit_base_patch16_384': _cfg( @@ -220,48 +210,6 @@ class PatchEmbed(nn.Module): return x -class HybridEmbed(nn.Module): - """ CNN Feature Map Embedding - Extract feature map from CNN, flatten, project to embedding dim. - """ - def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): - super().__init__() - assert isinstance(backbone, nn.Module) - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - self.img_size = img_size - self.patch_size = patch_size - self.backbone = backbone - if feature_size is None: - with torch.no_grad(): - # NOTE Most reliable way of determining output dims is to run forward pass - training = backbone.training - if training: - backbone.eval() - o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) - if isinstance(o, (list, tuple)): - o = o[-1] # last feature if backbone outputs list/tuple of features - feature_size = o.shape[-2:] - feature_dim = o.shape[1] - backbone.train(training) - else: - feature_size = to_2tuple(feature_size) - if hasattr(self.backbone, 'feature_info'): - feature_dim = self.backbone.feature_info.channels()[-1] - else: - feature_dim = self.backbone.num_features - assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 - self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] - self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) - - def forward(self, x): - x = self.backbone(x) - if isinstance(x, (list, tuple)): - x = x[-1] # last feature if backbone outputs list/tuple of features - x = self.proj(x).flatten(2).transpose(1, 2) - return x - - class VisionTransformer(nn.Module): """ Vision Transformer @@ -274,7 +222,7 @@ class VisionTransformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, act_layer=None, weight_init=''): """ Args: @@ -293,7 +241,7 @@ class VisionTransformer(nn.Module): drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate - hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module + embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer weight_init: (str): weight init scheme """ @@ -303,14 +251,9 @@ class VisionTransformer(nn.Module): self.num_tokens = 2 if distilled else 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU - patch_size = patch_size or (1 if hybrid_backbone is not None else 16) - if hybrid_backbone is not None: - self.patch_embed = HybridEmbed( - hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) - else: - self.patch_embed = PatchEmbed( - img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) @@ -489,8 +432,9 @@ def checkpoint_filter_fn(state_dict, model): return out_dict -def _create_vision_transformer(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) +def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): + if default_cfg is None: + default_cfg = deepcopy(default_cfgs[variant]) overlay_external_default_cfg(default_cfg, kwargs) default_num_classes = default_cfg['num_classes'] default_img_size = default_cfg['input_size'][-2:] @@ -680,22 +624,6 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs): return model -@register_model -def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs): - """ DeiT-tiny model""" - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_tiny_patch16_384(pretrained=False, **kwargs): - """ DeiT-tiny model""" - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_deit_small_patch16_224(pretrained=False, **kwargs): """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). @@ -706,48 +634,6 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs): return model -@register_model -def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch16_384(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_224(pretrained=False, **kwargs): - """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_deit_small_patch32_384(pretrained=False, **kwargs): - """ DeiT-small """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_deit_base_patch16_224(pretrained=False, **kwargs): """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 293dd34d..816bbc8e 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -9,6 +9,12 @@ keep file sizes sane. Hacked together by / Copyright 2020 Ross Wightman """ +from copy import deepcopy +from functools import partial + +import torch +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import StdConv2dSame, StdConv2d, to_2tuple from .resnet import resnet26d, resnet50d @@ -41,39 +47,14 @@ default_cfgs = { # hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones) 'vit_tiny_r_s16_p8_224': _cfg(), - 'vit_tiny_r_s16_p8_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - - 'vit_small_r_s16_p8_224': _cfg( - crop_pct=1.0), - 'vit_small_r_s16_p8_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - + 'vit_small_r_s16_p8_224': _cfg(), 'vit_small_r20_s16_p2_224': _cfg(), - 'vit_small_r20_s16_p2_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r20_s16_224': _cfg(), - 'vit_small_r20_s16_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_small_r26_s32_224': _cfg(), - 'vit_small_r26_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r20_s16_224': _cfg(), - 'vit_base_r20_s16_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r26_s32_224': _cfg(), - 'vit_base_r26_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), - 'vit_base_r50_s16_224': _cfg(), - 'vit_large_r50_s32_224': _cfg(), - 'vit_large_r50_s32_384': _cfg( - input_size=(3, 384, 384), crop_pct=1.0), # hybrid models (using timm resnet backbones) 'vit_small_resnet26d_224': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), @@ -83,6 +64,56 @@ default_cfgs = { } +class HybridEmbed(nn.Module): + """ CNN Feature Map Embedding + Extract feature map from CNN, flatten, project to embedding dim. + """ + def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768): + super().__init__() + assert isinstance(backbone, nn.Module) + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.backbone = backbone + if feature_size is None: + with torch.no_grad(): + # NOTE Most reliable way of determining output dims is to run forward pass + training = backbone.training + if training: + backbone.eval() + o = self.backbone(torch.zeros(1, in_chans, img_size[0], img_size[1])) + if isinstance(o, (list, tuple)): + o = o[-1] # last feature if backbone outputs list/tuple of features + feature_size = o.shape[-2:] + feature_dim = o.shape[1] + backbone.train(training) + else: + feature_size = to_2tuple(feature_size) + if hasattr(self.backbone, 'feature_info'): + feature_dim = self.backbone.feature_info.channels()[-1] + else: + feature_dim = self.backbone.num_features + assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 + self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] + self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.backbone(x) + if isinstance(x, (list, tuple)): + x = x[-1] # last feature if backbone outputs list/tuple of features + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): + default_cfg = deepcopy(default_cfgs[variant]) + embed_layer = partial(HybridEmbed, backbone=backbone) + kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set + return _create_vision_transformer( + variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + + def _resnetv2(layers=(3, 4, 9), **kwargs): """ ResNet-V2 backbone helper""" padding_same = kwargs.get('padding_same', True) @@ -108,9 +139,9 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -120,8 +151,9 @@ def vit_base_r50_s16_384(pretrained=False, **kwargs): ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -130,20 +162,9 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. """ backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -152,21 +173,10 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs): """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. """ backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs) - - return model - + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) -@register_model -def vit_small_r_s16_p8_384(pretrained=False, **kwargs): - """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224. - """ - backbone = _resnetv2(layers=(), **kwargs) - model_kwargs = dict( - patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs) return model @@ -175,20 +185,9 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs): """ R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224. """ backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict( - patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_p2_384(pretrained=False, **kwargs): - """ R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384. - """ - backbone = _resnetv2((2, 4), **kwargs) - model_kwargs = dict( - embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=2, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r20_s16_p2_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -197,18 +196,9 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-S/S16 hybrid. """ backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r20_s16_384(pretrained=False, **kwargs): - """ R20+ViT-S/S16 hybrid @ 384x384. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -217,18 +207,9 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-S/S32 hybrid. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_small_r26_s32_384(pretrained=False, **kwargs): - """ R26+ViT-S/S32 hybrid @ 384x384. - """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -237,18 +218,9 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs): """ R20+ViT-B/S16 hybrid. """ backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_r20_s16_384(pretrained=False, **kwargs): - """ R20+ViT-B/S16 hybrid. - """ - backbone = _resnetv2((2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r20_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -257,18 +229,9 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs): """ R26+ViT-B/S32 hybrid. """ backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_base_r26_s32_384(pretrained=False, **kwargs): - """ R26+ViT-B/S32 hybrid. - """ - backbone = _resnetv2((2, 2, 2, 2), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r26_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -277,8 +240,9 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs): """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929). """ backbone = _resnetv2((3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_r50_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -287,29 +251,9 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs): """ R50+ViT-L/S32 hybrid. """ backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs): - """ R50+ViT-L/S32 hybrid. - """ - backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict( - embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def vit_large_r50_s32_384(pretrained=False, **kwargs): - """ R50+ViT-L/S32 hybrid. - """ - backbone = _resnetv2((3, 4, 6, 3), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -318,8 +262,9 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. """ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -328,8 +273,9 @@ def vit_small_resnet50d_s16_224(pretrained=False, **kwargs): """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3]) - model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_small_resnet50d_s16_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -338,8 +284,9 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs): """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights. """ backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet26d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model @@ -348,6 +295,7 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs): """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights. """ backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4]) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs) - model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_vision_transformer_hybrid( + 'vit_base_resnet50d_224', backbone=backbone, pretrained=pretrained, **model_kwargs) return model \ No newline at end of file