Add numerous experimental ViT Hybrid models w/ ResNetV2 base. Update the ViT naming for hybrids. Fix #426 for pretrained vit resizing.

pull/450/head
Ross Wightman 3 years ago
parent 0e16d4e9fb
commit f0ffdf89b3

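The renaming maps the old hybrid names onto a scheme that encodes the backbone and grid: vit_base_resnet50_224_in21k becomes vit_base_r50_s16_224_in21k, where r50 appears to denote a ~50-layer ResNetV2 backbone (a bare `r` means stem only), s16 the combined output stride feeding the transformer, and an optional p2/p8 a patch size applied on top of the CNN feature map. A minimal usage sketch against the registry entries added below (most of the experimental variants ship without pretrained weights, hence pretrained=False):

import timm
import torch

# instantiate one of the renamed hybrids with random init
model = timm.create_model('vit_base_r50_s16_224_in21k', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # expect (1, 21843); this cfg carries the ImageNet-21k head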
timm/models/resnetv2.py
@@ -274,7 +274,9 @@ class ResNetStage(nn.Module):
         return x
 
 
-def create_stem(in_chs, out_chs, stem_type='', preact=True, conv_layer=None, norm_layer=None):
+def create_resnetv2_stem(
+        in_chs, out_chs=64, stem_type='', preact=True,
+        conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
     stem = OrderedDict()
     assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same')
@@ -322,7 +324,7 @@ class ResNetV2(nn.Module):
         self.feature_info = []
         stem_chs = make_div(stem_chs * wf)
-        self.stem = create_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
+        self.stem = create_resnetv2_stem(in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
         # NOTE no, reduction 2 feature if preact
         self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module='' if preact else 'stem.norm'))
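Giving the renamed stem helper real defaults (StdConv2d, 32-group GroupNormAct) makes it usable standalone, which is how the stem-only hybrids in vision_transformer.py below consume it. A quick sketch, assuming this commit's module layout:

import torch
from timm.models.resnetv2 import create_resnetv2_stem

# 7x7 stride-2 StdConv2d + GroupNormAct + stride-2 max pool -> overall stride 4
stem = create_resnetv2_stem(3, 64, stem_type='same', preact=False)
print(stem(torch.randn(1, 3, 224, 224)).shape)  # expect torch.Size([1, 64, 56, 56])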

timm/models/vision_transformer.py
@@ -28,9 +28,9 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
+from .layers import StdConv2dSame, StdConv2d, DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
-from .resnetv2 import ResNetV2
+from .resnetv2 import ResNetV2, create_resnetv2_stem
 from .registry import register_model
 
 _logger = logging.getLogger(__name__)
@@ -97,17 +97,62 @@ default_cfgs = {
         url='',  # FIXME I have weights for this but > 2GB limit for github release binaries
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
 
-    # hybrid models (weights ported from official Google JAX impl)
-    'vit_base_resnet50_224_in21k': _cfg(
+    # hybrid in-21k models (weights ported from official Google JAX impl where they exist)
+    'vit_base_r50_s16_224_in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
         num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
-    'vit_base_resnet50_384': _cfg(
+
+    # hybrid in-1k models (weights ported from official Google JAX impl where they exist)
+    'vit_tiny_r_s16_p8_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r_s16_p8_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r20_s16_p2_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r20_s16_p2_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r20_s16_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r20_s16_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r26_s32_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_small_r26_s32_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r20_s16_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r20_s16_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r26_s32_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r26_s32_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r50_s16_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_base_r50_s16_384': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
-        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0, first_conv='patch_embed.backbone.stem.conv'),
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_large_r50_s32_224': _cfg(
+        input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
+    'vit_large_r50_s32_384': _cfg(
+        input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
+        first_conv='patch_embed.backbone.stem.conv'),
 
     # hybrid models (my experiments)
     'vit_small_resnet26d_224': _cfg(),
-    'vit_small_resnet50d_s3_224': _cfg(),
+    'vit_small_resnet50d_s16_224': _cfg(),
     'vit_base_resnet26d_224': _cfg(),
     'vit_base_resnet50d_224': _cfg(),
@@ -227,11 +272,13 @@ class HybridEmbed(nn.Module):
     """ CNN Feature Map Embedding
     Extract feature map from CNN, flatten, project to embedding dim.
     """
-    def __init__(self, backbone, img_size=224, feature_size=None, in_chans=3, embed_dim=768):
+    def __init__(self, backbone, img_size=224, patch_size=1, feature_size=None, in_chans=3, embed_dim=768):
         super().__init__()
         assert isinstance(backbone, nn.Module)
         img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
         self.img_size = img_size
+        self.patch_size = patch_size
         self.backbone = backbone
         if feature_size is None:
             with torch.no_grad():
@@ -253,8 +300,9 @@ class HybridEmbed(nn.Module):
             feature_dim = self.backbone.feature_info.channels()[-1]
         else:
             feature_dim = self.backbone.num_features
-        self.num_patches = feature_size[0] * feature_size[1]
-        self.proj = nn.Conv2d(feature_dim, embed_dim, 1)
+        assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
+        self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
+        self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
 
     def forward(self, x):
         x = self.backbone(x)
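The new patch_size argument lets HybridEmbed patchify the backbone's feature map rather than projecting it one-to-one, with the assert guarding that the grid divides evenly. A self-contained sketch of that arithmetic, assuming a stride-16 backbone that maps 224x224 input to a 14x14, 1024-channel feature map:

import torch
import torch.nn as nn

feature_size, feature_dim, embed_dim = (14, 14), 1024, 768
for patch_size in [(1, 1), (2, 2)]:
    assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
    num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
    proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
    tokens = proj(torch.randn(2, feature_dim, *feature_size)).flatten(2).transpose(1, 2)
    print(patch_size, num_patches, tokens.shape)  # (1, 1) -> 196 tokens, (2, 2) -> 49 tokens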
@@ -270,9 +318,10 @@ class VisionTransformer(nn.Module):
     A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`  -
         https://arxiv.org/abs/2010.11929
     """
-    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
+    def __init__(self, img_size=224, patch_size=None, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
                  num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
+                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
+                 act_layer=None):
         """
         Args:
             img_size (int, tuple): input image size
@@ -296,10 +345,12 @@ class VisionTransformer(nn.Module):
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
+        act_layer = act_layer or nn.GELU
+        patch_size = patch_size or (1 if hybrid_backbone is not None else 16)  # parenthesized so an explicit patch_size is kept
 
         if hybrid_backbone is not None:
             self.patch_embed = HybridEmbed(
-                hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
+                hybrid_backbone, img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         else:
             self.patch_embed = PatchEmbed(
                 img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
@@ -313,7 +364,7 @@ class VisionTransformer(nn.Module):
         self.blocks = nn.ModuleList([
             Block(
                 dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
-                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
+                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
             for i in range(depth)])
         self.norm = norm_layer(embed_dim)
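With act_layer threaded through to every Block, the MLP nonlinearity becomes a constructor argument; one experimental base model below swaps in nn.SiLU. A hedged construction sketch using the public class directly:

import torch.nn as nn
from timm.models.vision_transformer import VisionTransformer

# GELU stays the default; pass any nn activation class to override
model = VisionTransformer(img_size=224, embed_dim=384, depth=12, num_heads=6, act_layer=nn.SiLU)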
@@ -423,13 +474,15 @@ class DistilledVisionTransformer(VisionTransformer):
         return (x + x_dist) / 2
 
 
-def resize_pos_embed(posemb, posemb_new):
+def resize_pos_embed(posemb, posemb_new, token='class'):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if True:
-        posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
-        ntok_new -= 1
+    if token:
+        assert token in ('class', 'distill')
+        token_idx = 2 if token == 'distill' else 1
+        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
+        ntok_new -= token_idx  # exclude all prefix tokens from the grid count
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
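The remainder of resize_pos_embed (outside this hunk) reshapes posemb_grid into its 2D layout, bilinearly resizes it, and re-attaches the prefix tokens; this is the path the #426 fix exercises when loading pretrained weights at a new resolution. A standalone sketch of that resampling (the helper name is illustrative, not from this file):

import math
import torch
import torch.nn.functional as F

def resize_grid(posemb_grid, gs_new):
    # (old_h * old_w, dim) -> (1, gs_new * gs_new, dim), assuming a square grid
    gs_old = int(math.sqrt(posemb_grid.shape[0]))
    grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(gs_new, gs_new), mode='bilinear', align_corners=False)
    return grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)

print(resize_grid(torch.randn(196, 768), 24).shape)  # 14x14 grid -> torch.Size([1, 576, 768])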
@@ -633,33 +686,190 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     return model
 
 
+def _resnetv2(layers=(3, 4, 9), **kwargs):
+    """ ResNet-V2 backbone helper"""
+    padding_same = kwargs.get('padding_same', True)
+    if padding_same:
+        stem_type = 'same'
+        conv_layer = StdConv2dSame
+    else:
+        stem_type = ''
+        conv_layer = StdConv2d
+    if len(layers):
+        backbone = ResNetV2(
+            layers=layers, num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
+            preact=False, stem_type=stem_type, conv_layer=conv_layer)
+    else:
+        backbone = create_resnetv2_stem(
+            kwargs.get('in_chans', 3), stem_type=stem_type, preact=False, conv_layer=conv_layer)
+    return backbone
+
+
 @register_model
-def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
+def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
     """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    backbone = ResNetV2(
-        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
     model_kwargs = dict(
-        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
-        representation_size=768, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
+        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, representation_size=768, **kwargs)
+    model = _create_vision_transformer('vit_base_r50_s16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
 
 
+@register_model
+def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
+    """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
+    """
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(
+        patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
+    """ R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
+    """
+    backbone = _resnetv2(layers=(), **kwargs)
+    model_kwargs = dict(
+        patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
+    """ R20+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
+    """
+    backbone = _resnetv2((2, 4), **kwargs)
+    model_kwargs = dict(
+        patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
+    """ R20+ViT-S/S16 w/ 2x2 patch hybrid @ 384x384.
+    """
+    backbone = _resnetv2((2, 4), **kwargs)
+    model_kwargs = dict(
+        patch_size=2, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r20_s16_p2_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r20_s16_224(pretrained=False, **kwargs):
+    """ R20+ViT-S/S16 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r20_s16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r20_s16_384(pretrained=False, **kwargs):
+    """ R20+ViT-S/S16 hybrid @ 384x384.
+    """
+    backbone = _resnetv2((2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r20_s16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_224(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r26_s32_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_small_r26_s32_384(pretrained=False, **kwargs):
+    """ R26+ViT-S/S32 hybrid @ 384x384.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_small_r26_s32_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r20_s16_224(pretrained=False, **kwargs):
+    """ R20+ViT-B/S16 hybrid w/ SiLU MLP activation.
+    """
+    backbone = _resnetv2((2, 2, 2), **kwargs)
+    model_kwargs = dict(
+        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs)
+    model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r20_s16_384(pretrained=False, **kwargs):
+    """ R20+ViT-B/S16 hybrid @ 384x384.
+    """
+    backbone = _resnetv2((2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_base_r20_s16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r26_s32_224(pretrained=False, **kwargs):
+    """ R26+ViT-B/S32 hybrid.
+    """
+    backbone = _resnetv2((2, 2, 2, 2), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_base_r26_s32_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_base_r50_s16_224(pretrained=False, **kwargs):
+    """ R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
+    """
+    backbone = _resnetv2((3, 4, 9), **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_base_r50_s16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
 @register_model
-def vit_base_resnet50_384(pretrained=False, **kwargs):
+def vit_base_r50_s16_384(pretrained=False, **kwargs):
     """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    # create a ResNetV2 w/o pre-activation, that uses StdConv and GroupNorm and has 3 stages, no head
-    backbone = ResNetV2(
-        layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
-        preact=False, stem_type='same', conv_layer=StdConv2dSame)
+    backbone = _resnetv2((3, 4, 9), **kwargs)
     model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('vit_base_r50_s16_384', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_large_r50_s32_224(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_large_r50_s32_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_large_r50_s32_384(pretrained=False, **kwargs):
+    """ R50+ViT-L/S32 hybrid @ 384x384.
+    """
+    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
+    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, hybrid_backbone=backbone, **kwargs)
+    model = _create_vision_transformer('vit_large_r50_s32_384', pretrained=pretrained, **model_kwargs)
+    return model
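As a sanity check on the stride naming, two of the variants registered above produce the expected token grids at 224x224 (random init, so no weights are downloaded): stride 16 gives a 14x14 grid of 196 patches, stride 32 a 7x7 grid of 49.

import timm

for name, expect in [('vit_base_r50_s16_224', 196), ('vit_small_r26_s32_224', 49)]:
    model = timm.create_model(name, pretrained=False)
    print(name, model.patch_embed.num_patches, 'expected', expect)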
@@ -674,12 +884,12 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
 
 
 @register_model
-def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
+def vit_small_resnet50d_s16_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
     """
     backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
     model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
-    model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
+    model = _create_vision_transformer('vit_small_resnet50d_s16_224', pretrained=pretrained, **model_kwargs)
     return model
