Support npz custom load for vision transformer hybrid models. Add posembed rescale for npz load.

cleanup_xla_model_fixes
Ross Wightman 4 years ago
parent 8319e0c373
commit b9cfb64412

@ -27,7 +27,8 @@ class AvgPool2dSame(nn.AvgPool2d):
super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad) super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
def forward(self, x): def forward(self, x):
return avg_pool2d_same( x = pad_same(x, self.kernel_size, self.stride)
return F.avg_pool2d(
x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad) x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
@ -41,14 +42,15 @@ def max_pool2d_same(
class MaxPool2dSame(nn.MaxPool2d): class MaxPool2dSame(nn.MaxPool2d):
""" Tensorflow like 'SAME' wrapper for 2D max pooling """ Tensorflow like 'SAME' wrapper for 2D max pooling
""" """
def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False, count_include_pad=True): def __init__(self, kernel_size: int, stride=None, padding=0, dilation=1, ceil_mode=False):
kernel_size = to_2tuple(kernel_size) kernel_size = to_2tuple(kernel_size)
stride = to_2tuple(stride) stride = to_2tuple(stride)
dilation = to_2tuple(dilation) dilation = to_2tuple(dilation)
super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode, count_include_pad) super(MaxPool2dSame, self).__init__(kernel_size, stride, (0, 0), dilation, ceil_mode)
def forward(self, x): def forward(self, x):
return max_pool2d_same(x, self.kernel_size, self.stride, self.padding, self.dilation, self.ceil_mode) x = pad_same(x, self.kernel_size, self.stride, value=-float('inf'))
return F.max_pool2d(x, self.kernel_size, self.stride, (0, 0), self.dilation, self.ceil_mode)
def create_pool2d(pool_type, kernel_size, stride=None, **kwargs): def create_pool2d(pool_type, kernel_size, stride=None, **kwargs):

@ -52,6 +52,10 @@ default_cfgs = {
url='', url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
), ),
'vit_tiny_patch16_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
'vit_small_patch16_224': _cfg( 'vit_small_patch16_224': _cfg(
url='', url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
@ -60,6 +64,14 @@ default_cfgs = {
url='', url='',
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5),
), ),
'vit_small_patch16_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
'vit_small_patch32_384': _cfg(
url='',
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0
),
# patch models (weights ported from official Google JAX impl) # patch models (weights ported from official Google JAX impl)
'vit_base_patch16_224': _cfg( 'vit_base_patch16_224': _cfg(
@ -102,6 +114,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
'vit_huge_patch14_224_in21k': _cfg( 'vit_huge_patch14_224_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/imagenet21k/ViT-H_14.npz',
hf_hub='timm/vit_huge_patch14_224_in21k', hf_hub='timm/vit_huge_patch14_224_in21k',
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)), num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
@ -371,24 +384,53 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
import numpy as np import numpy as np
def _n2p(w, t=True): def _n2p(w, t=True):
if t and w.ndim == 4: if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
w = w.transpose([3, 2, 0, 1]) w = w.flatten()
elif t and w.ndim == 3: if t:
w = w.transpose([2, 0, 1]) if w.ndim == 4:
elif t and w.ndim == 2: w = w.transpose([3, 2, 0, 1])
w = w.transpose([1, 0]) elif w.ndim == 3:
w = w.transpose([2, 0, 1])
elif w.ndim == 2:
w = w.transpose([1, 0])
return torch.from_numpy(w) return torch.from_numpy(w)
w = np.load(checkpoint_path) w = np.load(checkpoint_path)
if not prefix: if not prefix and 'opt/target/embedding/kernel' in w:
prefix = 'opt/target/' if 'opt/target/embedding/kernel' in w else prefix prefix = 'opt/target/'
input_conv_w = adapt_input_conv( if hasattr(model.patch_embed, 'backbone'):
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel'])) # hybrid
model.patch_embed.proj.weight.copy_(input_conv_w) backbone = model.patch_embed.backbone
stem_only = not hasattr(backbone, 'stem')
stem = backbone if stem_only else backbone.stem
stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
if not stem_only:
for i, stage in enumerate(backbone.stages):
for j, block in enumerate(stage.blocks):
bp = f'{prefix}block{i + 1}/unit{j + 1}/'
for r in range(3):
getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
if block.downsample is not None:
block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
else:
embed_conv_w = adapt_input_conv(
model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
model.patch_embed.proj.weight.copy_(embed_conv_w)
model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias'])) model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False)) model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
model.pos_embed.copy_(_n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)) pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: if model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
@ -396,23 +438,18 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
for i, block in enumerate(model.blocks.children()): for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/' block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale'])) block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias'])) block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
block.attn.qkv.weight.copy_(torch.cat([ block.attn.qkv.weight.copy_(torch.cat([
_n2p(w[f'{mha_prefix}query/kernel'], t=False).flatten(1).T, _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}key/kernel'], t=False).flatten(1).T,
_n2p(w[f'{mha_prefix}value/kernel'], t=False).flatten(1).T]))
block.attn.qkv.bias.copy_(torch.cat([ block.attn.qkv.bias.copy_(torch.cat([
_n2p(w[f'{mha_prefix}query/bias'], t=False).reshape(-1), _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
_n2p(w[f'{mha_prefix}key/bias'], t=False).reshape(-1),
_n2p(w[f'{mha_prefix}value/bias'], t=False).reshape(-1)]))
block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1)) block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias'])) block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
block.mlp.fc1.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/kernel'])) for r in range(2):
block.mlp.fc1.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_0/bias'])) getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
block.mlp.fc2.weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/kernel'])) getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
block.mlp.fc2.bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_1/bias']))
block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale'])) block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
@ -478,6 +515,7 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
default_cfg=default_cfg, default_cfg=default_cfg,
representation_size=repr_size, representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in default_cfg['url'],
**kwargs) **kwargs)
return model return model
@ -510,6 +548,16 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
NOTE I've replaced my previous 'small' model definition and weights with the small variant from the DeiT paper
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_patch16_224(pretrained=False, **kwargs): def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).

@ -35,26 +35,34 @@ def _cfg(url='', **kwargs):
default_cfgs = { default_cfgs = {
# hybrid in-21k models (weights ported from official Google JAX impl where they exist) # hybrid in-1k models (weights ported from official JAX impl where they exist)
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
# hybrid in-1k models (weights ported from official JAX impl)
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384': _cfg(
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'), 'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
'vit_small_r20_s16_p2_224': _cfg(), 'vit_small_r20_s16_p2_224': _cfg(),
'vit_small_r20_s16_224': _cfg(), 'vit_small_r20_s16_224': _cfg(),
'vit_small_r26_s32_224': _cfg(), 'vit_small_r26_s32_224': _cfg(),
'vit_small_r26_s32_384': _cfg(
input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r20_s16_224': _cfg(), 'vit_base_r20_s16_224': _cfg(),
'vit_base_r26_s32_224': _cfg(), 'vit_base_r26_s32_224': _cfg(),
'vit_base_r50_s16_224': _cfg(), 'vit_base_r50_s16_224': _cfg(),
'vit_base_r50_s16_384': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224': _cfg(), 'vit_large_r50_s32_224': _cfg(),
'vit_large_r50_s32_384': _cfg(),
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
'vit_small_r26_s32_224_in21k': _cfg(
num_classes=21843, crop_pct=0.9),
'vit_small_r26_s32_384_in21k': _cfg(
num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
'vit_base_r50_s16_224_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
# hybrid models (using timm resnet backbones) # hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg( 'vit_small_resnet26d_224': _cfg(
@ -99,7 +107,8 @@ class HybridEmbed(nn.Module):
else: else:
feature_dim = self.backbone.num_features feature_dim = self.backbone.num_features
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0 assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1] self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size) self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x): def forward(self, x):
@ -133,37 +142,35 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
@register_model @register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929). """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
""" """
backbone = _resnetv2(layers=(3, 4, 9), **kwargs) backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs): def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929). """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
""" """
backbone = _resnetv2((3, 4, 9), **kwargs) backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model
@register_model @register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs): def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224. """ R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
""" """
backbone = _resnetv2(layers=(), **kwargs) backbone = _resnetv2(layers=(), **kwargs)
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs) model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model
@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_r20_s16_224(pretrained=False, **kwargs): def vit_base_r20_s16_224(pretrained=False, **kwargs):
""" R20+ViT-B/S16 hybrid. """ R20+ViT-B/S16 hybrid.
@ -245,17 +263,74 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2((3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs): def vit_large_r50_s32_224(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid. """ R50+ViT-L/S32 hybrid.
""" """
backbone = _resnetv2((3, 4, 6, 3), **kwargs) backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model
@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
""" R50+ViT-L/S32 hybrid.
"""
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_large_r50_s32_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
""" R26+ViT-S/S32 hybrid.
"""
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_small_r26_s32_384_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_resnet26d_224(pretrained=False, **kwargs): def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights. """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.

Loading…
Cancel
Save