@@ -121,6 +121,15 @@ default_cfgs = {
     'vit_deit_base_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_deit_tiny_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+    'vit_deit_small_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+    'vit_deit_base_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+    'vit_deit_base_distilled_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
 }
@@ -367,6 +376,53 @@ class VisionTransformer(nn.Module):
         return x
 
 
+class DistilledVisionTransformer(VisionTransformer):
+    """ Vision Transformer with distillation token.
+
+    Paper: `Training data-efficient image transformers & distillation through attention` -
+        https://arxiv.org/abs/2012.12877
+
+    This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        num_patches = self.patch_embed.num_patches
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
+        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.dist_token, std=.02)
+        trunc_normal_(self.pos_embed, std=.02)
+        self.head_dist.apply(self._init_weights)
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x[:, 0], x[:, 1]
+
+    def forward(self, x):
+        x, x_dist = self.forward_features(x)
+        x = self.head(x)
+        x_dist = self.head_dist(x_dist)
+        if self.training:
+            return x, x_dist
+        else:
+            # during inference, return the average of both classifier predictions
+            return (x + x_dist) / 2
+
+
 def resize_pos_embed(posemb, posemb_new):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
@@ -396,7 +452,8 @@ def checkpoint_filter_fn(state_dict, model):
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
             # For old models that I trained prior to conv based patchification
-            v = v.reshape(model.patch_embed.proj.weight.shape)
+            O, I, H, W = model.patch_embed.proj.weight.shape
+            v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
             v = resize_pos_embed(v, model.pos_embed)
@@ -404,7 +461,7 @@ def checkpoint_filter_fn(state_dict, model):
     return out_dict
 
 
-def _create_vision_transformer(variant, pretrained=False, **kwargs):
+def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
     default_cfg = default_cfgs[variant]
     default_num_classes = default_cfg['num_classes']
     default_img_size = default_cfg['input_size'][-1]
@@ -418,7 +475,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
         _logger.warning("Removing representation layer for fine-tuning.")
         repr_size = None
 
-    model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
+    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
+    model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
     model.default_cfg = default_cfg
 
     if pretrained:
@@ -446,7 +504,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -455,7 +513,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
 def vit_base_patch32_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -465,7 +523,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -475,9 +533,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -487,7 +543,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -496,7 +552,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
 def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -506,7 +562,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -516,7 +572,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -527,7 +583,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -538,7 +594,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -549,7 +605,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -560,7 +616,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
+        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -572,7 +628,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     NOTE: converted weights not currently available, too large for github release hosting.
     """
     model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs)
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
     model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -587,7 +643,7 @@ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
         layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
     model_kwargs = dict(
-        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
+        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
         representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -602,7 +658,7 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
     backbone = ResNetV2(
         layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -611,7 +667,7 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
 def vit_small_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
     """
-    backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
+    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
     model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -621,7 +677,7 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
 def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
     """
-    backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[3])
+    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
     model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -631,8 +687,8 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
 def vit_base_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
     """
-    backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -641,8 +697,8 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs):
 def vit_base_resnet50d_224(pretrained=False, **kwargs):
     """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
    """
-    backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -652,7 +708,7 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -662,7 +718,7 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -672,7 +728,7 @@ def vit_deit_base_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -682,6 +738,50 @@ def vit_deit_base_patch16_384(pretrained=False, **kwargs):
     """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
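Usage sketch. The distilled variants added here return a (class-head, distillation-head) logit pair in training mode and the average of the two heads at inference, per `DistilledVisionTransformer.forward` above. A minimal example of exercising one of the new registrations through `timm.create_model`, assuming this patch is applied so the names are in the registry:

```python
import torch
import timm  # assumes this patch is applied, so the new @register_model names exist

# Model name comes from the @register_model defs added in this diff.
model = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=False)

x = torch.randn(2, 3, 224, 224)  # dummy batch at the model's default input size

model.train()
logits_cls, logits_dist = model(x)  # training: separate cls-token and dist-token heads

model.eval()
with torch.no_grad():
    logits = model(x)  # inference: average of the two head outputs

print(logits_cls.shape, logits_dist.shape, logits.shape)  # all (2, 1000) with default num_classes
```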