|
|
|
@ -103,48 +103,90 @@ default_cfgs = {
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=0.9, first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
# hybrid in-1k models (weights ported from official Google JAX impl where they exist)
|
|
|
|
|
'vit_tiny_r_s16_p8_224': _cfg(
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_tiny_r_s16_p8_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_tiny_r_s16_p8_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_small_r_s16_p8_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r_s16_p8_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r_s16_p8_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_small_r20_s16_p2_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r20_s16_p2_224_in21k': _cfg(
|
|
|
|
|
inum_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r20_s16_p2_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'vit_small_r20_s16_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r20_s16_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r20_s16_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_small_r26_s32_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r26_s32_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_small_r26_s32_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_base_r20_s16_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_base_r20_s16_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_base_r20_s16_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_base_r26_s32_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_base_r26_s32_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_base_r26_s32_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_base_r50_s16_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_base_r50_s16_384': _cfg(
|
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
|
|
|
|
|
'vit_large_r50_s32_224': _cfg(
|
|
|
|
|
input_size=(3, 224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_large_r50_s32_224_in21k': _cfg(
|
|
|
|
|
num_classes=21843, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
|
first_conv='patch_embed.backbone.stem.conv'),
|
|
|
|
|
'vit_large_r50_s32_384': _cfg(
|
|
|
|
|
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0,
|
|
|
|
@ -159,8 +201,19 @@ default_cfgs = {
|
|
|
|
|
# deit models (FB weights)
|
|
|
|
|
'vit_deit_tiny_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'),
|
|
|
|
|
'vit_deit_tiny_patch16_224_in21k': _cfg(num_classes=21843),
|
|
|
|
|
'vit_deit_tiny_patch16_224_in21k_norep': _cfg(num_classes=21843),
|
|
|
|
|
'vit_deit_tiny_patch16_384': _cfg(input_size=(3, 384, 384)),
|
|
|
|
|
|
|
|
|
|
'vit_deit_small_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'),
|
|
|
|
|
'vit_deit_small_patch16_224_in21k': _cfg(num_classes=21843),
|
|
|
|
|
'vit_deit_small_patch16_384': _cfg(input_size=(3, 384, 384)),
|
|
|
|
|
|
|
|
|
|
'vit_deit_small_patch32_224': _cfg(),
|
|
|
|
|
'vit_deit_small_patch32_224_in21k': _cfg(num_classes=21843),
|
|
|
|
|
'vit_deit_small_patch32_384': _cfg(input_size=(3, 384, 384)),
|
|
|
|
|
|
|
|
|
|
'vit_deit_base_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth',),
|
|
|
|
|
'vit_deit_base_patch16_384': _cfg(
|
|
|
|
@ -728,7 +781,29 @@ def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_p2_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_tiny_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_tiny_r_s16_p8_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=192, depth=12, num_heads=3, representation_size=192, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_tiny_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_tiny_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=192, depth=12, num_heads=3, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_tiny_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -740,6 +815,29 @@ def vit_small_r_s16_p8_224(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r_s16_p8_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r_s16_p8_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r_s16_p8_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r_s16_p8_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R+ViT-S/S16 w/ 8x8 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2(layers=(), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=8, embed_dim=384, depth=12, num_heads=6, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r_s16_p8_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -754,6 +852,17 @@ def vit_small_r20_s16_p2_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r20_s16_p2_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R52+ViT-S/S16 w/ 2x2 patch hybrid @ 224 x 224.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 4), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=2, embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_p2_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r20_s16_p2_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-S/S16 w/ 2x2 Patch hybrid @ 384x384.
|
|
|
|
@ -775,6 +884,16 @@ def vit_small_r20_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r20_s16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-S/S16 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r20_s16_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-S/S16 hybrid @ 384x384.
|
|
|
|
@ -795,6 +914,17 @@ def vit_small_r26_s32_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r26_s32_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R26+ViT-S/S32 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=384, depth=12, num_heads=6, representation_size=384, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_small_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_small_r26_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R26+ViT-S/S32 hybrid @ 384x384.
|
|
|
|
@ -810,12 +940,22 @@ def vit_base_r20_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, act_layer=nn.SiLU, **kwargs)
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_base_r20_s16_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r20_s16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_base_r20_s16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r20_s16_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R20+ViT-B/S16 hybrid.
|
|
|
|
@ -836,6 +976,27 @@ def vit_base_r26_s32_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r26_s32_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R26+ViT-B/S32 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_base_r26_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r26_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R26+ViT-B/S32 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((2, 2, 2, 2), **kwargs)
|
|
|
|
|
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_base_r26_s32_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_base_r50_s16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" R50+ViT-B/S16 hybrid from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
@ -867,6 +1028,17 @@ def vit_large_r50_s32_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_large_r50_s32_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
|
"""
|
|
|
|
|
backbone = _resnetv2((3, 4, 6, 3), **kwargs)
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
embed_dim=768, depth=12, num_heads=12, representation_size=768, hybrid_backbone=backbone, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_large_r50_s32_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_large_r50_s32_384(pretrained=False, **kwargs):
|
|
|
|
|
""" R50+ViT-L/S32 hybrid.
|
|
|
|
@ -927,6 +1099,31 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_tiny_patch16_224_in21k_norep(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-tiny model"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k_norep', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_tiny_patch16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-tiny model"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, representation_size=192, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_tiny_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_tiny_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-tiny model"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
|
|
|
@ -937,6 +1134,48 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch16_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small """
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_small_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small """
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_small_patch16_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch32_224(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
|
|
|
|
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
|
|
|
|
"""
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_small_patch32_224', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch32_224_in21k(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small """
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, representation_size=384, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_small_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_small_patch32_384(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT-small """
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
|
model = _create_vision_transformer('vit_deit_small_patch32_384', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def vit_deit_base_patch16_224(pretrained=False, **kwargs):
|
|
|
|
|
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
|
|
|
|
|