@@ -123,14 +123,17 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth',
         input_size=(3, 384, 384), crop_pct=1.0),
     'vit_deit_tiny_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_small_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth'),
+        url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_224': _cfg(
-        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', ),
+        url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth',
+        classifier=('head', 'head_dist')),
     'vit_deit_base_distilled_patch16_384': _cfg(
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
-        input_size=(3, 384, 384), crop_pct=1.0),
+        input_size=(3, 384, 384), crop_pct=1.0, classifier=('head', 'head_dist')),
 }
 
 
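The distilled configs above now declare classifier=('head', 'head_dist'), so helpers that touch the classifier must handle a tuple of head names as well as a single string. A minimal sketch of that pattern (strip_heads is a hypothetical helper, not timm API):

    def strip_heads(state_dict, classifier):
        # `classifier` may now be a tuple like ('head', 'head_dist') or a single name
        names = classifier if isinstance(classifier, (tuple, list)) else (classifier,)
        for name in names:
            state_dict.pop(name + '.weight', None)
            state_dict.pop(name + '.bias', None)
        return state_dict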
@@ -302,6 +305,7 @@ class VisionTransformer(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 2 if distilled else 1
         norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
 
         if hybrid_backbone is not None:
@@ -313,9 +317,8 @@ class VisionTransformer(nn.Module):
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
-        num_tokens = 2 if distilled else 1
-        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
+        self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
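The token count is now computed once as self.num_tokens and reused when sizing pos_embed, so the embedding always covers the class token, the optional distillation token, and every patch position. A quick shape check with illustrative values (not from the diff):

    import torch

    num_patches, embed_dim, distilled = 196, 192, True  # 224/16 = 14, so 14*14 patches
    num_tokens = 2 if distilled else 1                  # cls + dist for distilled models
    pos_embed = torch.zeros(1, num_patches + num_tokens, embed_dim)
    assert pos_embed.shape == (1, 198, 192)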
@@ -382,10 +385,10 @@ class VisionTransformer(nn.Module):
         x = self.pos_drop(x + self.pos_embed)
         x = self.blocks(x)
         x = self.norm(x)
-        if self.dist_token is not None:
-            return x[:, 0], x[:, 1]
-        else:
+        if self.dist_token is None:
             return self.pre_logits(x[:, 0])
+        else:
+            return x[:, 0], x[:, 1]
 
     def forward(self, x):
         x = self.forward_features(x)
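With this change, forward_features returns a single pre_logits tensor for a plain ViT and a (cls, dist) pair for a distilled model, so callers have to branch on the return type. A hedged sketch of one way a caller might consume it (model and imgs are hypothetical; averaging the two heads at inference follows the DeiT paper):

    feats = model.forward_features(imgs)  # hypothetical model and input batch
    if isinstance(feats, tuple):
        # distilled: separate class-token and distillation-token embeddings
        logits = (model.head(feats[0]) + model.head_dist(feats[1])) / 2
    else:
        logits = model.head(feats)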
@@ -401,15 +404,13 @@ class VisionTransformer(nn.Module):
         return x
 
 
-def resize_pos_embed(posemb, posemb_new, token='class'):
+def resize_pos_embed(posemb, posemb_new, num_tokens=1):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if token:
-        assert token in ('class', 'distill')
-        token_idx = 2 if token == 'distill' else 1
-        posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
-        ntok_new -= 1
+    if num_tokens:
+        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
+        ntok_new -= num_tokens
     else:
         posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
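The rest of resize_pos_embed is unchanged and therefore outside this hunk; it square-reshapes the grid portion and bilinearly interpolates it to the new size. A self-contained sketch of that step, assuming a square patch grid (resize_grid is an illustrative helper, not the function's real structure):

    import math
    import torch
    import torch.nn.functional as F

    def resize_grid(posemb_grid, ntok_new):
        gs_old = int(math.sqrt(posemb_grid.shape[0]))  # e.g. 14 for a 196-patch grid
        gs_new = int(math.sqrt(ntok_new))              # e.g. 24 for a 576-patch grid
        grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        grid = F.interpolate(grid, size=(gs_new, gs_new), mode='bilinear')
        return grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)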
@@ -436,7 +437,7 @@ def checkpoint_filter_fn(state_dict, model):
             v = v.reshape(O, -1, H, W)
         elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
             # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
+            v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
         out_dict[k] = v
     return out_dict
 
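End to end, these changes let a distilled checkpoint load cleanly: the filter resizes pos_embed on the fly, and getattr falls back to a single token for model classes without a num_tokens attribute. A usage sketch against the cfg names above (assumes the timm version this diff targets; timm.create_model is real API):

    import timm

    model = timm.create_model('vit_deit_base_distilled_patch16_224', pretrained=True)
    print(model.num_tokens)  # 2: class token + distillation token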