@@ -28,9 +28,9 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2
from .resnetv2 import ResNetV2, create_resnetv2_stem
from .registry import register_model

_logger = logging.getLogger(__name__)

@@ -97,17 +97,62 @@ default_cfgs = {
        url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),

    # hybrid models (weights ported from official Google JAX impl)
    'vit_base_resnet50_224_in21k': _cfg(
    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
    'vit_base_r50_s16_224_in21k': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
        num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_resnet50_384': _cfg(

    # hybrid in-1k models (weights ported from official Google JAX impl where they exist)
    'vit_tiny_r_s16_p8_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r_s16_p8_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r20_s16_p2_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r20_s16_p2_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r20_s16_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r20_s16_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r26_s32_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_small_r26_s32_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r20_s16_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r20_s16_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r26_s32_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r26_s32_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r50_s16_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_base_r50_s16_384': _cfg(
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_large_r50_s32_224': _cfg(
        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),
    'vit_large_r50_s32_384': _cfg(
        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
        first_conv='patch_embed.backbone.stem.conv'),

    # hybrid models (my experiments)
    'vit_small_resnet26d_224': _cfg(),
    'vit_small_resnet50d_s3_224': _cfg(),
    'vit_small_resnet50d_s16_224': _cfg(),
    'vit_base_resnet26d_224': _cfg(),
    'vit_base_resnet50d_224': _cfg(),
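
# Illustrative sketch (not taken from this diff; the helper name below is hypothetical): the _cfg
# fields above (input_size, mean, std, crop_pct, first_conv) are what timm's data helpers read back
# off a model to build matching eval transforms for these hybrid variants.
def _example_data_config():
    import timm
    from timm.data import resolve_data_config, create_transform
    model = timm.create_model('vit_base_r50_s16_384', pretrained=False)
    config = resolve_data_config({}, model=model)  # picks up input_size / mean / std / crop_pct
    return create_transform(**config)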

@@ -227,11 +272,13 @@ class HybridEmbed(nn.Module):
    """ CNN Feature Map Embedding
    Extract feature map from CNN, flatten, project to embedding dim.
    """
    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
        super().__init__()
        assert isinstance(backbone, nn.Module)
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.backbone = backbone
        if feature_size is None:
            with torch.no_grad():

@@ -253,8 +300,9 @@ class HybridEmbed(nn.Module):
                feature_dim = self.backbone.feature_info.channels()[-1]
            else:
                feature_dim = self.backbone.num_features
        self.num_patches = feature_size[0] * feature_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.backbone(x)
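
# Shape sketch (illustrative, not part of the diff; helper name is hypothetical): with a 224x224
# input, a ResNetV2 (3, 4, 9) backbone that reduces by 16x, patch_size=1 and embed_dim=768, the
# backbone's B x C x 14 x 14 feature map is projected and flattened to B x 196 x 768 patch tokens.
def _example_hybrid_embed_shapes():
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=3,
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    embed = HybridEmbed(backbone, img_size=224, patch_size=1, in_chans=3, embed_dim=768)
    tokens = embed(torch.randn(2, 3, 224, 224))
    return tokens.shape  # expected: torch.Size([2, 196, 768])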

@@ -270,9 +318,10 @@ class VisionTransformer(nn.Module):
    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
        https://arxiv.org/abs/2010.11929
    """
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
    def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
                 act_layer=None):
        """
        Args:
            img_size (int, tuple): input image size

@@ -296,10 +345,12 @@ class VisionTransformer(nn.Module):
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
        act_layer = act_layer or nn.GELU
        patch_size = patch_size or (1 if hybrid_backbone is not None else 16)  # default: 1 for hybrid, 16 for pure ViT

        if hybrid_backbone is not None:
            self.patch_embed = HybridEmbed(
                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
                hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        else:
            self.patch_embed = PatchEmbed(
                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)

@@ -313,7 +364,7 @@ class VisionTransformer(nn.Module):
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
            for i in range(depth)])
        self.norm = norm_layer(embed_dim)

@@ -423,13 +474,15 @@ class DistilledVisionTransformer(VisionTransformer):
        return (x + x_dist) / 2


def resize_pos_embed(posemb, posemb_new):
def resize_pos_embed(posemb, posemb_new, token='class'):
    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
    ntok_new = posemb_new.shape[1]
    if True:
        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
    if token:
        assert token in ('class', 'distill')
        token_idx = 2 if token == 'distill' else 1
        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
        ntok_new -= 1
    else:
        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
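
# Illustrative sketch (not part of the diff; helper name is hypothetical): resizing a pos embed
# trained at 224 (1 class token + 14*14 grid) to a 384 grid (1 + 24*24) keeps the prefix token(s)
# and interpolates only the spatial grid.
def _example_resize_pos_embed():
    posemb_224 = torch.randn(1, 1 + 14 * 14, 768)
    posemb_384 = torch.zeros(1, 1 + 24 * 24, 768)  # stand-in for the new model's pos_embed
    resized = resize_pos_embed(posemb_224, posemb_384, token='class')
    return resized.shape  # expected: torch.Size([1, 577, 768])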

@@ -633,33 +686,190 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
    return model


def _resnetv2(layers=(3, 4, 9), **kwargs):
    """ ResNet-V2 backbone helper"""
    padding_same = kwargs.get('padding_same', True)
    if padding_same:
        stem_type = 'same'
        conv_layer = StdConv2dSame
    else:
        stem_type = ''
        conv_layer = StdConv2d
    if len(layers):
        backbone = ResNetV2(
            layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
            preact=False, stem_type=stem_type, conv_layer=conv_layer)
    else:
        backbone = create_resnetv2_stem(
            kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
    return backbone
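

# Usage sketch (illustrative, not part of the diff; helper name is hypothetical): a non-empty
# `layers` tuple yields a headless, staged ResNetV2 feature extractor (e.g. (3, 4, 9) is the
# 3-stage, stride-16 backbone of the R50 hybrids), while layers=() yields just the stem used by
# the small "R+" hybrids, leaving most of the spatial reduction to the ViT patch projection.
def _example_resnetv2_backbones():
    r50_backbone = _resnetv2(layers=(3, 4, 9))  # staged, stride-16 feature map
    stem_backbone = _resnetv2(layers=())        # stem-only feature map
    return r50_backbone, stem_backbone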


@register_model
def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
    model_kwargs = dict(
        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs)
    return model
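

# Usage sketch (illustrative, not part of the diff; helper name is hypothetical): the @register_model
# functions in this file are normally reached through the timm factory by their registered names;
# this diff renames 'vit_base_resnet50_224_in21k' to 'vit_base_r50_s16_224_in21k'.
def _example_create_hybrid():
    import timm
    return timm.create_model('vit_base_r50_s16_224_in21k', pretrained=False, num_classes=10)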


@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2(layers=(), **kwargs)
    model_kwargs = dict(
        patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
    """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2(layers=(), **kwargs)
    model_kwargs = dict(
        patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
    """
    backbone = _resnetv2((2, 4), **kwargs)
    model_kwargs = dict(
        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
        representation_size=768, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
        patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 w/ 2x2 patch hybrid @ 384x384.
    """
    backbone = _resnetv2((2, 4), **kwargs)
    model_kwargs = dict(
        embed_dim=384, patch_size=2, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r20_s16_224(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_resnet50_384(pretrained=False, **kwargs):
def vit_small_r20_s16_384(pretrained=False, **kwargs):
    """ R20+ViT-S/S16 hybrid @ 384x384.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r26_s32_224(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid @ 384x384.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs):
    """ R20+ViT-B/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(
        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs)
    model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r20_s16_384(pretrained=False, **kwargs):
    """ R20+ViT-B/S16 hybrid.
    """
    backbone = _resnetv2((2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r26_s32_224(pretrained=False, **kwargs):
    """ R26+ViT-B/S32 hybrid.
    """
    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r50_s16_224(pretrained=False, **kwargs):
    """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    """
    backbone = _resnetv2((3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
    """
    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
    backbone = ResNetV2(
        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
        preact=False, stem_type='same', conv_layer=StdConv2dSame)
    backbone = _resnetv2((3, 4, 9), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.
    """
    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs)
    return model


@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.
    """
    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs)
    return model

@@ -674,12 +884,12 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):


@register_model
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
    """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
    """
    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
    model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
    model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
    model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs)
    return model
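

# Illustrative sketch (not part of the diff; helper name is hypothetical): the resnet26d/resnet50d
# hybrids above use timm's features_only mechanism; out_indices=[3] selects the stride-16 feature
# level, and HybridEmbed infers its spatial size and channel count with a dummy forward pass
# (or from backbone.feature_info when an explicit feature_size is given).
def _example_features_only_backbone():
    backbone = resnet50d(pretrained=False, features_only=True, out_indices=[3])
    return backbone.feature_info.channels()  # e.g. [1024] for the selected level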