|
|
|
@ -1,7 +1,10 @@
|
|
|
|
|
""" DeiT - Data-efficient Image Transformers
|
|
|
|
|
|
|
|
|
|
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
|
|
|
|
|
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
|
|
|
|
|
|
|
|
|
paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
|
|
|
|
|
|
|
|
|
paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
|
|
|
|
|
|
|
|
|
|
Modifications copyright 2021, Ross Wightman
|
|
|
|
|
"""
|
|
|
|
@ -53,6 +56,46 @@ default_cfgs = {
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0,
|
|
|
|
|
classifier=('head', 'head_dist')),
|
|
|
|
|
|
|
|
|
|
'deit3_small_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
|
|
|
|
|
'deit3_small_patch16_384': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_base_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
|
|
|
|
|
'deit3_base_patch16_384': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_large_patch16_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
|
|
|
|
|
'deit3_large_patch16_384': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_huge_patch14_224': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
|
|
|
|
|
|
|
|
|
|
'deit3_small_patch16_224_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
|
|
|
|
|
crop_pct=1.0),
|
|
|
|
|
'deit3_small_patch16_384_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_base_patch16_224_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
|
|
|
|
|
crop_pct=1.0),
|
|
|
|
|
'deit3_base_patch16_384_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_large_patch16_224_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
|
|
|
|
|
crop_pct=1.0),
|
|
|
|
|
'deit3_large_patch16_384_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
|
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0),
|
|
|
|
|
'deit3_huge_patch14_224_in21ft1k': _cfg(
|
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
|
|
|
|
|
crop_pct=1.0),
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
|
|
|
|
|
super().__init__(*args, **kwargs, weight_init='skip')
|
|
|
|
|
assert self.global_pool in ('token',)
|
|
|
|
|
|
|
|
|
|
self.num_tokens = 2
|
|
|
|
|
self.num_prefix_tokens = 2
|
|
|
|
|
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
|
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
|
|
|
|
|
self.pos_embed = nn.Parameter(
|
|
|
|
|
torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
|
|
|
|
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
|
|
|
|
self.distilled_training = False # must set this True to train w/ distillation token
|
|
|
|
|
|
|
|
|
@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
|
|
|
|
model = _create_deit(
|
|
|
|
|
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs):
    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_small_patch16_224' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_small_patch16_224', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs):
    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_small_patch16_384' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_small_patch16_384', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs):
    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_base_patch16_224' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_base_patch16_224', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs):
    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_base_patch16_384' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_base_patch16_384', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs):
    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_large_patch16_224' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_large_patch16_224', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs):
    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_large_patch16_384' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_large_patch16_384', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs):
    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-1k weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_huge_patch14_224' variant.
    """
    # NOTE: this variant uses 14x14 patches, unlike the patch16 small/base/large entrypoints.
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_small_patch16_224_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_small_patch16_384_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=384,
        depth=12,
        num_heads=6,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_base_patch16_224_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_base_patch16_384_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=768,
        depth=12,
        num_heads=12,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_large_patch16_224_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_large_patch16_384_in21ft1k' variant.
    """
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=16,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.

    Args:
        pretrained: passed through unchanged to `_create_deit`.
        **kwargs: additional model arguments, combined with the fixed settings below.

    Returns:
        The result of `_create_deit` for the 'deit3_huge_patch14_224_in21ft1k' variant.
    """
    # NOTE: this variant uses 14x14 patches, unlike the patch16 small/base/large entrypoints.
    # Fixed settings come first; a kwarg repeating one of these names makes dict() raise TypeError.
    arch_args = dict(
        patch_size=14,
        embed_dim=1280,
        depth=32,
        num_heads=16,
        no_embed_class=True,
        init_values=1e-6,
        **kwargs)
    return _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **arch_args)
|
|
|
|
|