|
|
|
@ -35,26 +35,34 @@ def _cfg(url='', **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
|
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
|
|
|
|
|
'vit_base_r50_s16_224_in21k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
|
|
|
|
num_classes=21843, crop_pct=0.9),
|
|
|
|
|
|
|
|
|
|
# hybrid in-1k models (weights ported from official JAX impl)
|
|
|
|
|
'vit_base_r50_s16_384': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
|
|
|
|
|
# hybrid in-1k models (mostly untrained, experimental configs w/ resnetv2 stdconv backbones)
|
|
|
|
|
# hybrid in-1k models (weights ported from official JAX impl where they exist)
|
|
|
|
|
'vit_tiny_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
|
|
|
|
|
'vit_tiny_r_s16_p8_384': _cfg(
|
|
|
|
|
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'vit_small_r_s16_p8_224': _cfg(first_conv='patch_embed.backbone.conv'),
|
|
|
|
|
'vit_small_r20_s16_p2_224': _cfg(),
|
|
|
|
|
'vit_small_r20_s16_224': _cfg(),
|
|
|
|
|
'vit_small_r26_s32_224': _cfg(),
|
|
|
|
|
'vit_small_r26_s32_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'vit_base_r20_s16_224': _cfg(),
|
|
|
|
|
'vit_base_r26_s32_224': _cfg(),
|
|
|
|
|
'vit_base_r50_s16_224': _cfg(),
|
|
|
|
|
'vit_base_r50_s16_384': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'vit_large_r50_s32_224': _cfg(),
|
|
|
|
|
'vit_large_r50_s32_384': _cfg(),
|
|
|
|
|
|
|
|
|
|
# hybrid in-21k models (weights ported from official Google JAX impl where they exist)
|
|
|
|
|
'vit_small_r26_s32_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, crop_pct=0.9),
|
|
|
|
|
'vit_small_r26_s32_384_in21k': _cfg(
|
|
|
|
|
num_classes=21843, input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'vit_base_r50_s16_224_in21k': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
|
|
|
|
|
num_classes=21843, crop_pct=0.9),
|
|
|
|
|
'vit_large_r50_s32_224_in21k': _cfg(num_classes=21843, crop_pct=0.9),
|
|
|
|
|
|
|
|
|
|
# hybrid models (using timm resnet backbones)
|
|
|
|
|
'vit_small_resnet26d_224': _cfg(
|
|
|
|
@ -99,7 +107,8 @@ class HybridEmbed(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
feature_dim = self.backbone.num_features
|
|
|
|
|
assert feature_size[0] % patch_size[0] == 0 and feature_size[1] % patch_size[1] == 0
|
|
|
|
|
self.num_patches = feature_size[0] // patch_size[0] * feature_size[1] // patch_size[1]
|
|
|
|
|
self.grid_size = (feature_size[0] // patch_size[0], feature_size[1] // patch_size[1])
|
|
|
|
|
self.num_patches = self.grid_size[0] * self.grid_size[1]
|
|
|
|
|
self.proj = nn.Conv2d(feature_dim, embed_dim, kernel_size=patch_size, stride=patch_size)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
@ -133,37 +142,35 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
|
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
|
|
|
|
|
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r50_s16_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
|
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
|
|
|
|
|
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((3, 4, 9), **kwargs)
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
'vit_base_r50_s16_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 384 x 384.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(patch_size=8, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
|
model = _create_vision_transformer_hybrid(
|
|
|
|
|
'vit_tiny_r_s16_p8_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
'vit_tiny_r_s16_p8_384', backbone=backbone, pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -212,6 +219,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_r26_s32_384(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid.

    Pairs a ResNetV2 backbone built with stage depths (2, 2, 2, 2) with a
    transformer of width 384, 12 blocks and 6 attention heads, registered under
    the 'vit_small_r26_s32_384' config name.

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer
            arguments, so every key must be accepted by both.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    # Backbone is built first, then the transformer args, matching the original order.
    resnet_stem = _resnetv2(layers=(2, 2, 2, 2), **kwargs)
    vit_args = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
    return _create_vision_transformer_hybrid(
        'vit_small_r26_s32_384', backbone=resnet_stem, pretrained=pretrained, **vit_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r20_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
@ -245,17 +263,74 @@ def vit_base_r50_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_r50_s16_384(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.

    Uses a three-stage (3, 4, 9) ResNetV2 backbone feeding a transformer with
    embed_dim=768, depth=12 and 12 heads.

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    variant = 'vit_base_r50_s16_384'
    r50_stem = _resnetv2(layers=(3, 4, 9), **kwargs)
    transformer_args = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
    hybrid = _create_vision_transformer_hybrid(
        variant, backbone=r50_stem, pretrained=pretrained, **transformer_args)
    return hybrid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_r50_s32_224(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.

    Combines a four-stage (3, 4, 6, 3) ResNetV2 backbone with a ViT-Large sized
    transformer (embed_dim=1024, depth=24, num_heads=16).

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    backbone = _resnetv2((3, 4, 6, 3), **kwargs)
    # A stale ViT-Base assignment (768/12/12) used to precede this line and was
    # immediately overwritten; only the ViT-Large dimensions are kept.
    model_kwargs = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
    model = _create_vision_transformer_hybrid(
        'vit_large_r50_s32_224', backbone=backbone, pretrained=pretrained, **model_kwargs)
    return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_large_r50_s32_384(pretrained=False, **kwargs):
    """ R50+ViT-L/S32 hybrid.

    Same architecture arguments as the 224 variant (ResNetV2 (3, 4, 6, 3) backbone,
    embed_dim=1024, depth=24, num_heads=16) but registered under the
    'vit_large_r50_s32_384' config name.

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    r50_stem = _resnetv2(layers=(3, 4, 6, 3), **kwargs)
    large_args = dict(embed_dim=1024, depth=24, num_heads=16, **kwargs)
    return _create_vision_transformer_hybrid(
        'vit_large_r50_s32_384', backbone=r50_stem, pretrained=pretrained, **large_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid.

    ResNetV2 (2, 2, 2, 2) backbone plus a transformer with embed_dim=384, depth=12
    and 6 heads, registered under the 'vit_small_r26_s32_224_in21k' config name
    (that config sets num_classes=21843).

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    variant = 'vit_small_r26_s32_224_in21k'
    r26_stem = _resnetv2(layers=(2, 2, 2, 2), **kwargs)
    small_args = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
    return _create_vision_transformer_hybrid(
        variant, backbone=r26_stem, pretrained=pretrained, **small_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_small_r26_s32_384_in21k(pretrained=False, **kwargs):
    """ R26+ViT-S/S32 hybrid.

    ResNetV2 (2, 2, 2, 2) backbone plus a transformer with embed_dim=384, depth=12
    and 6 heads, registered under the 'vit_small_r26_s32_384_in21k' config name
    (that config sets num_classes=21843 and a 384 x 384 input size).

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    feature_extractor = _resnetv2(layers=(2, 2, 2, 2), **kwargs)
    vit_args = dict(embed_dim=384, depth=12, num_heads=6, **kwargs)
    hybrid = _create_vision_transformer_hybrid(
        'vit_small_r26_s32_384_in21k',
        backbone=feature_extractor,
        pretrained=pretrained,
        **vit_args,
    )
    return hybrid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
    """ R50+ViT-B/16 hybrid model from original paper (https://arxiv.org/abs/2010.11929).
    ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.

    Uses a three-stage (3, 4, 9) ResNetV2 backbone and a transformer with
    embed_dim=768, depth=12, 12 heads and representation_size=768.

    Args:
        pretrained: passed through unchanged to `_create_vision_transformer_hybrid`.
        **kwargs: forwarded to BOTH the backbone builder and the transformer arguments.

    Returns:
        The object produced by `_create_vision_transformer_hybrid`.
    """
    r50_stem = _resnetv2(layers=(3, 4, 9), **kwargs)
    # representation_size is only set on this in21k variant among the R50+B/16 entries.
    base_args = dict(
        embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
    return _create_vision_transformer_hybrid(
        'vit_base_r50_s16_224_in21k', backbone=r50_stem, pretrained=pretrained, **base_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_resnet26d_224(pretrained=False, **kwargs):
|
|
|
|
|
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
|
|
|
|
|