|
|
@ -118,6 +118,17 @@ default_cfgs = {
|
|
|
|
'vit_deit_base_distilled_patch16_384': _cfg(
|
|
|
|
'vit_deit_base_distilled_patch16_384': _cfg(
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
|
|
|
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
|
|
|
|
input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ViT ImageNet-21K-P pretraining
|
|
|
|
|
|
|
|
'vit_base_patch16_224_21k_miil': _cfg(
|
|
|
|
|
|
|
|
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/vit_base_patch16_224_21k_miil.pth',
|
|
|
|
|
|
|
|
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
'vit_base_patch16_224_1k_miil': _cfg(
|
|
|
|
|
|
|
|
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm'
|
|
|
|
|
|
|
|
'/vit_base_patch16_224_1k_miil_84_4.pth',
|
|
|
|
|
|
|
|
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
|
|
|
|
|
|
|
|
),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -155,7 +166,7 @@ class Attention(nn.Module):
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
B, N, C = x.shape
|
|
|
|
B, N, C = x.shape
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
|
|
|
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
|
|
|
|
|
|
|
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
@ -652,7 +663,7 @@ def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -663,7 +674,7 @@ def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -674,7 +685,7 @@ def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -687,3 +698,21 @@ def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
model = _create_vision_transformer(
|
|
|
|
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_patch16_224_21k_miil(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
|
|
|
|
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base_patch16_224_21k_miil', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def vit_base_patch16_224_1k_miil(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
|
|
|
|
|
|
|
|
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
|
|
|
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
|
|
|
|
|
|
|
model = _create_vision_transformer('vit_base_patch16_224_1k_miil_84_4', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
return model
|