Add DeiT distilled weights and distilled model def. Remove some redudant ViT model args.

pull/323/head
Ross Wightman 4 years ago
parent c16e965037
commit bb50ac4708

@ -121,6 +121,15 @@ default_cfgs = {
'vit_deit_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_deit_tiny_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
'vit_deit_small_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
'vit_deit_base_distilled_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
'vit_deit_base_distilled_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0),
}
@ -367,6 +376,53 @@ class VisionTransformer(nn.Module):
return x
class DistilledVisionTransformer(VisionTransformer):
""" Vision Transformer with distillation token.
Paper: `Training data-efficient image transformers & distillation through attention` -
https://arxiv.org/abs/2012.12877
This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
num_patches = self.patch_embed.num_patches
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
trunc_normal_(self.dist_token, std=.02)
trunc_normal_(self.pos_embed, std=.02)
self.head_dist.apply(self._init_weights)
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.pos_embed
x = self.pos_drop(x)
for blk in self.blocks:
x = blk(x)
x = self.norm(x)
return x[:, 0], x[:, 1]
def forward(self, x):
x, x_dist = self.forward_features(x)
x = self.head(x)
x_dist = self.head_dist(x_dist)
if self.training:
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
def resize_pos_embed(posemb, posemb_new):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
@ -396,7 +452,8 @@ def checkpoint_filter_fn(state_dict, model):
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
v = v.reshape(model.patch_embed.proj.weight.shape)
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed)
@ -404,7 +461,7 @@ def checkpoint_filter_fn(state_dict, model):
return out_dict
def _create_vision_transformer(variant, pretrained=False, **kwargs):
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
default_cfg = default_cfgs[variant]
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-1]
@ -418,7 +475,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
model_cls = DistilledVisionTransformer if distilled else VisionTransformer
model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
model.default_cfg = default_cfg
if pretrained:
@ -446,7 +504,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -455,7 +513,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@ -465,7 +523,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -475,9 +533,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@ -487,7 +543,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -496,7 +552,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
return model
@ -506,7 +562,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@ -516,7 +572,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
return model
@ -527,7 +583,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -538,7 +594,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -549,7 +605,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -560,7 +616,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -572,7 +628,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
NOTE: converted weights not currently available, too large for github release hosting.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs)
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -587,7 +643,7 @@ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type='same', conv_layer=StdConv2dSame)
model_kwargs = dict(
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
representation_size=768, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
return model
@ -602,7 +658,7 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
backbone = ResNetV2(
layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
preact=False, stem_type='same', conv_layer=StdConv2dSame)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
return model
@ -611,7 +667,7 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
def vit_small_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model
@ -621,7 +677,7 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
""" Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[3])
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
return model
@ -631,8 +687,8 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
def vit_base_resnet26d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
"""
backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
return model
@ -641,8 +697,8 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs):
def vit_base_resnet50d_224(pretrained=False, **kwargs):
""" Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
"""
backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
return model
@ -652,7 +708,7 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -662,7 +718,7 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -672,7 +728,7 @@ def vit_deit_base_patch16_224(pretrained=False, **kwargs):
""" DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@ -682,6 +738,50 @@ def vit_deit_base_patch16_384(pretrained=False, **kwargs):
""" DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer(
'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer(
'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
""" DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer(
'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
Loading…
Cancel
Save