From bb50ac470867eb681f688350e046d9abf5ad3bb8 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 25 Jan 2021 11:05:23 -0800
Subject: [PATCH] Add DeiT distilled weights and distilled model def. Remove some redundant ViT model args.

---
 timm/models/vision_transformer.py | 160 ++++++++++++++++++++++++------
 1 file changed, 130 insertions(+), 30 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 90122090..ff2510f1 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -121,6 +121,15 @@ default_cfgs = {
     'vit_deit_base_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
+    'vit_deit_tiny_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+    'vit_deit_small_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+    'vit_deit_base_distilled_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+    'vit_deit_base_distilled_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
 }
@@ -367,6 +376,53 @@ class VisionTransformer(nn.Module):
         return x


+class DistilledVisionTransformer(VisionTransformer):
+    """ Vision Transformer with distillation token.
+
+    Paper: `Training data-efficient image transformers & distillation through attention` -
+        https://arxiv.org/abs/2012.12877
+
+    This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
+    """
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+        num_patches = self.patch_embed.num_patches
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
+        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
+
+        trunc_normal_(self.dist_token, std=.02)
+        trunc_normal_(self.pos_embed, std=.02)
+        self.head_dist.apply(self._init_weights)
+
+    def forward_features(self, x):
+        B = x.shape[0]
+        x = self.patch_embed(x)
+
+        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+        dist_token = self.dist_token.expand(B, -1, -1)
+        x = torch.cat((cls_tokens, dist_token, x), dim=1)
+
+        x = x + self.pos_embed
+        x = self.pos_drop(x)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x = self.norm(x)
+        return x[:, 0], x[:, 1]
+
+    def forward(self, x):
+        x, x_dist = self.forward_features(x)
+        x = self.head(x)
+        x_dist = self.head_dist(x_dist)
+        if self.training:
+            return x, x_dist
+        else:
+            # during inference, return the average of both classifier predictions
+            return (x + x_dist) / 2
+
+
 def resize_pos_embed(posemb, posemb_new):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
@@ -396,7 +452,8 @@ def checkpoint_filter_fn(state_dict, model):
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
             # For old models that I trained prior to conv based patchification
-            v = v.reshape(model.patch_embed.proj.weight.shape)
+            O, I, H, W = model.patch_embed.proj.weight.shape
+            v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
             v = resize_pos_embed(v, model.pos_embed)
@@ -404,7 +461,7 @@
     return out_dict


-def _create_vision_transformer(variant, pretrained=False, **kwargs):
+def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
     default_cfg = default_cfgs[variant]
     default_num_classes = default_cfg['num_classes']
     default_img_size = default_cfg['input_size'][-1]
@@ -418,7 +475,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
         _logger.warning("Removing representation layer for fine-tuning.")
         repr_size = None

-    model = VisionTransformer(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
+    model_cls = DistilledVisionTransformer if distilled else VisionTransformer
+    model = model_cls(img_size=img_size, num_classes=num_classes, representation_size=repr_size, **kwargs)
     model.default_cfg = default_cfg

     if pretrained:
@@ -446,7 +504,7 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -455,7 +513,7 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
     """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -465,7 +523,7 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -475,9 +533,7 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
     """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
-        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -487,7 +543,7 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -496,7 +552,7 @@ def vit_large_patch32_224(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -506,7 +562,7 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -516,7 +572,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
     """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
     ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
     """
-    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -527,7 +583,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -538,7 +594,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, representation_size=768, **kwargs)
+        patch_size=32, embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -549,7 +605,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
     model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -560,7 +616,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
     ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
     """
     model_kwargs = dict(
-        patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, representation_size=1024, **kwargs)
+        patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
     model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -572,7 +628,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
     NOTE: converted weights not currently available, too large for github release hosting.
     """
     model_kwargs = dict(
-        patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, representation_size=1280, **kwargs)
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
     model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -587,7 +643,7 @@ def vit_base_resnet50_224_in21k(pretrained=False, **kwargs):
         layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
     model_kwargs = dict(
-        embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone,
+        embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone,
         representation_size=768, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50_224_in21k', pretrained=pretrained, **model_kwargs)
     return model
@@ -602,7 +658,7 @@ def vit_base_resnet50_384(pretrained=False, **kwargs):
     backbone = ResNetV2(
         layers=(3, 4, 9), num_classes=0, global_pool='', in_chans=kwargs.get('in_chans', 3),
         preact=False, stem_type='same', conv_layer=StdConv2dSame)
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50_384', pretrained=pretrained, **model_kwargs)
     return model
@@ -611,7 +667,7 @@ def vit_small_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet26D stride 32. No pretrained weights.
     """
-    backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
+    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
     model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_small_resnet26d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -621,7 +677,7 @@ def vit_small_resnet50d_s3_224(pretrained=False, **kwargs):
     """ Custom ViT small hybrid w/ ResNet50D 3-stages, stride 16. No pretrained weights.
     """
-    backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[3])
+    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[3])
     model_kwargs = dict(embed_dim=768, depth=8, num_heads=8, mlp_ratio=3, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_small_resnet50d_s3_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -631,8 +687,8 @@ def vit_base_resnet26d_224(pretrained=False, **kwargs):
     """ Custom ViT base hybrid w/ ResNet26D stride 32. No pretrained weights.
     """
-    backbone = resnet26d(pretrained=pretrained, features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    backbone = resnet26d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet26d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -641,8 +697,8 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
     """ Custom ViT base hybrid w/ ResNet50D stride 32. No pretrained weights.
     """
-    backbone = resnet50d(pretrained=pretrained, features_only=True, out_indices=[4])
-    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
+    backbone = resnet50d(pretrained=pretrained, in_chans=kwargs.get('in_chans', 3), features_only=True, out_indices=[4])
+    model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, hybrid_backbone=backbone, **kwargs)
     model = _create_vision_transformer('vit_base_resnet50d_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -652,7 +708,7 @@ def vit_deit_tiny_patch16_224(pretrained=False, **kwargs):
     """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
     model = _create_vision_transformer('vit_deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -662,7 +718,7 @@ def vit_deit_small_patch16_224(pretrained=False, **kwargs):
     """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
     model = _create_vision_transformer('vit_deit_small_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -672,7 +728,7 @@ def vit_deit_base_patch16_224(pretrained=False, **kwargs):
     """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_deit_base_patch16_224', pretrained=pretrained, **model_kwargs)
     return model
@@ -682,6 +738,50 @@ def vit_deit_base_patch16_384(pretrained=False, **kwargs):
     """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
     ImageNet-1k weights from https://github.com/facebookresearch/deit.
     """
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
     model = _create_vision_transformer('vit_deit_base_patch16_384', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def vit_deit_tiny_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_small_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
+
+
+@register_model
+def vit_deit_base_distilled_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
+    model = _create_vision_transformer(
+        'vit_deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
+    return model
\ No newline at end of file