From 7d4b3807d5c40b0f8d7e66d27a7672684e482996 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 4 Jul 2022 22:25:22 -0700 Subject: [PATCH] Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit. --- timm/models/deit.py | 204 +++++++++++++++++++++++++++++- timm/models/vision_transformer.py | 64 +++++++--- 2 files changed, 247 insertions(+), 21 deletions(-) diff --git a/timm/models/deit.py b/timm/models/deit.py index e6b4b025..a2f43b91 100644 --- a/timm/models/deit.py +++ b/timm/models/deit.py @@ -1,7 +1,10 @@ """ DeiT - Data-efficient Image Transformers DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118 Modifications copyright 2021, Ross Wightman """ @@ -53,6 +56,46 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')), + + 'deit3_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'), + 'deit3_small_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'), + 'deit3_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'), + 'deit3_large_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'), + + 'deit3_small_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth', + crop_pct=1.0), + 'deit3_small_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_base_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth', + crop_pct=1.0), + 'deit3_base_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_large_patch16_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth', + crop_pct=1.0), + 'deit3_large_patch16_384_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth', + input_size=(3, 384, 384), crop_pct=1.0), + 'deit3_huge_patch14_224_in21ft1k': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth', + crop_pct=1.0), } @@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer): super().__init__(*args, **kwargs, weight_init='skip') assert self.global_pool in ('token',) - self.num_tokens = 2 + self.num_prefix_tokens = 2 self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim)) self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() self.distilled_training = False # must set this True to train w/ distillation token @@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs): model = _create_deit( 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) return model + + +@register_model +def deit3_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs): + """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118). + ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs) + model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 8551feae..022052d0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -325,8 +325,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -360,15 +360,17 @@ class VisionTransformer(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 if class_token else 0 + self.num_prefix_tokens = 1 if class_token else 0 + self.no_embed_class = no_embed_class self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None - self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None + embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens + self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -428,11 +430,24 @@ class VisionTransformer(nn.Module): self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + def _pos_embed(self, x): + if self.no_embed_class: + # deit-3, updated JAX (big vision) + # position embedding does not overlap with class token, add then concat + x = x + self.pos_embed + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + else: + # original timm, JAX, and deit vit impl + # pos_embed has entry for class token, concat then add + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.pos_embed + return self.pos_drop(x) + def forward_features(self, x): x = self.patch_embed(x) - if self.cls_token is not None: - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) - x = self.pos_drop(x + self.pos_embed) + x = self._pos_embed(x) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: @@ -442,7 +457,7 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) return x if pre_logits else self.head(x) @@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False) if pos_embed_w.shape != model.pos_embed.shape: pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights - pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + pos_embed_w, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) model.pos_embed.copy_(pos_embed_w) model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale'])) model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias'])) @@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias'])) -def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): +def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if num_tokens: - posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] - ntok_new -= num_tokens + if num_prefix_tokens: + posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:] + ntok_new -= num_prefix_tokens else: - posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + posemb_prefix, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) if not len(gs_new): # backwards compatibility gs_new = [int(math.sqrt(ntok_new))] * 2 @@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False) posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) - posemb = torch.cat([posemb_tok, posemb_grid], dim=1) + posemb = torch.cat([posemb_prefix, posemb_grid], dim=1) return posemb def checkpoint_filter_fn(state_dict, model): """ convert patch embedding weight from manual patchify + linear proj to conv""" + import re out_dict = {} if 'model' in state_dict: # For deit models state_dict = state_dict['model'] + for k, v in state_dict.items(): if 'patch_embed.proj.weight' in k and len(v.shape) < 4: # For old models that I trained prior to conv based patchification O, I, H, W = model.patch_embed.proj.weight.shape v = v.reshape(O, -1, H, W) - elif k == 'pos_embed' and v.shape != model.pos_embed.shape: + elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]: # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( - v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + v, + model.pos_embed, + getattr(model, 'num_prefix_tokens', 1), + model.patch_embed.grid_size + ) + elif 'gamma_' in k: + # remap layer-scale gamma into sub-module (deit3 models) + k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k) elif 'pre_logits' in k: # NOTE representation layer removed as not used in latest 21k/1k pretrained weights continue