Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit.

pull/1327/head
Ross Wightman 2 years ago
parent d0c5bd5722
commit 7d4b3807d5

@ -1,7 +1,10 @@
""" DeiT - Data-efficient Image Transformers
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
Modifications copyright 2021, Ross Wightman
"""
@ -53,6 +56,46 @@ default_cfgs = {
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
input_size=(3, 384, 384), crop_pct=1.0,
classifier=('head', 'head_dist')),
'deit3_small_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
'deit3_small_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
'deit3_base_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
'deit3_large_patch16_384': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
'deit3_small_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
crop_pct=1.0),
'deit3_small_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_base_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
crop_pct=1.0),
'deit3_base_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_large_patch16_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
crop_pct=1.0),
'deit3_large_patch16_384_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'deit3_huge_patch14_224_in21ft1k': _cfg(
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
crop_pct=1.0),
}
@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
super().__init__(*args, **kwargs, weight_init='skip')
assert self.global_pool in ('token',)
self.num_tokens = 2
self.num_prefix_tokens = 2
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
self.distilled_training = False # must set this True to train w/ distillation token
@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
model = _create_deit(
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
return model
@register_model
def deit3_small_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-1k weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
return model

@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
"""
Args:
img_size (int, tuple): input image size
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0
self.num_prefix_tokens = 1 if class_token else 0
self.no_embed_class = no_embed_class
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _pos_embed(self, x):
if self.no_embed_class:
# deit-3, updated JAX (big vision)
# position embedding does not overlap with class token, add then concat
x = x + self.pos_embed
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
else:
# original timm, JAX, and deit vit impl
# pos_embed has entry for class token, concat then add
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.pos_embed
return self.pos_drop(x)
def forward_features(self, x):
x = self.patch_embed(x)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed)
x = self._pos_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
if pos_embed_w.shape != model.pos_embed.shape:
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
pos_embed_w,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
model.pos_embed.copy_(pos_embed_w)
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
ntok_new = posemb_new.shape[1]
if num_tokens:
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
ntok_new -= num_tokens
if num_prefix_tokens:
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
ntok_new -= num_prefix_tokens
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
return posemb
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
import re
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
# For old models that I trained prior to conv based patchification
O, I, H, W = model.patch_embed.proj.weight.shape
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
v,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
elif 'gamma_' in k:
# remap layer-scale gamma into sub-module (deit3 models)
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
elif 'pre_logits' in k:
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
continue

Loading…
Cancel
Save