|
|
@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
|
|
|
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
|
|
|
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
|
|
|
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
|
|
|
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
|
|
|
|
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
|
|
|
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
|
|
|
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
img_size (int, tuple): input image size
|
|
|
|
img_size (int, tuple): input image size
|
|
|
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
|
|
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
|
|
|
self.num_tokens = 1 if class_token else 0
|
|
|
|
self.num_prefix_tokens = 1 if class_token else 0
|
|
|
|
|
|
|
|
self.no_embed_class = no_embed_class
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
|
|
|
|
|
|
|
self.patch_embed = embed_layer(
|
|
|
|
self.patch_embed = embed_layer(
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
|
|
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
|
|
|
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
|
|
|
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
|
|
|
|
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
|
|
|
|
|
|
|
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
|
|
|
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.global_pool = global_pool
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def _pos_embed(self, x):
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
if self.no_embed_class:
|
|
|
|
|
|
|
|
# deit-3, updated JAX (big vision)
|
|
|
|
|
|
|
|
# position embedding does not overlap with class token, add then concat
|
|
|
|
|
|
|
|
x = x + self.pos_embed
|
|
|
|
|
|
|
|
if self.cls_token is not None:
|
|
|
|
|
|
|
|
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
# original timm, JAX, and deit vit impl
|
|
|
|
|
|
|
|
# pos_embed has entry for class token, concat then add
|
|
|
|
if self.cls_token is not None:
|
|
|
|
if self.cls_token is not None:
|
|
|
|
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
|
|
|
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
|
|
|
x = self.pos_drop(x + self.pos_embed)
|
|
|
|
x = x + self.pos_embed
|
|
|
|
|
|
|
|
return self.pos_drop(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
|
|
|
x = self._pos_embed(x)
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
x = checkpoint_seq(self.blocks, x)
|
|
|
|
x = checkpoint_seq(self.blocks, x)
|
|
|
|
else:
|
|
|
|
else:
|
|
|
@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
if self.global_pool:
|
|
|
|
if self.global_pool:
|
|
|
|
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
|
|
|
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
|
|
|
x = self.fc_norm(x)
|
|
|
|
x = self.fc_norm(x)
|
|
|
|
return x if pre_logits else self.head(x)
|
|
|
|
return x if pre_logits else self.head(x)
|
|
|
|
|
|
|
|
|
|
|
@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
|
|
|
|
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
|
|
|
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
|
|
|
if pos_embed_w.shape != model.pos_embed.shape:
|
|
|
|
if pos_embed_w.shape != model.pos_embed.shape:
|
|
|
|
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
|
|
|
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
|
|
|
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
|
|
|
pos_embed_w,
|
|
|
|
|
|
|
|
model.pos_embed,
|
|
|
|
|
|
|
|
getattr(model, 'num_prefix_tokens', 1),
|
|
|
|
|
|
|
|
model.patch_embed.grid_size
|
|
|
|
|
|
|
|
)
|
|
|
|
model.pos_embed.copy_(pos_embed_w)
|
|
|
|
model.pos_embed.copy_(pos_embed_w)
|
|
|
|
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
|
|
|
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
|
|
|
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
|
|
|
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
|
|
@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
|
|
|
|
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
|
|
|
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
|
|
|
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
|
|
|
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
|
|
|
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
|
|
|
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
|
|
|
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
|
|
|
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
|
|
|
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
|
|
|
ntok_new = posemb_new.shape[1]
|
|
|
|
ntok_new = posemb_new.shape[1]
|
|
|
|
if num_tokens:
|
|
|
|
if num_prefix_tokens:
|
|
|
|
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
|
|
|
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
|
|
|
|
ntok_new -= num_tokens
|
|
|
|
ntok_new -= num_prefix_tokens
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
|
|
|
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
|
|
|
|
gs_old = int(math.sqrt(len(posemb_grid)))
|
|
|
|
gs_old = int(math.sqrt(len(posemb_grid)))
|
|
|
|
if not len(gs_new): # backwards compatibility
|
|
|
|
if not len(gs_new): # backwards compatibility
|
|
|
|
gs_new = [int(math.sqrt(ntok_new))] * 2
|
|
|
|
gs_new = [int(math.sqrt(ntok_new))] * 2
|
|
|
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
|
|
|
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
|
|
|
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
|
|
|
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
|
|
|
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
|
|
|
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
|
|
|
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
|
|
|
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
|
|
|
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
|
|
|
|
return posemb
|
|
|
|
return posemb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
if 'model' in state_dict:
|
|
|
|
if 'model' in state_dict:
|
|
|
|
# For deit models
|
|
|
|
# For deit models
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
state_dict = state_dict['model']
|
|
|
|
|
|
|
|
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
|
|
|
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
|
|
|
# For old models that I trained prior to conv based patchification
|
|
|
|
# For old models that I trained prior to conv based patchification
|
|
|
|
O, I, H, W = model.patch_embed.proj.weight.shape
|
|
|
|
O, I, H, W = model.patch_embed.proj.weight.shape
|
|
|
|
v = v.reshape(O, -1, H, W)
|
|
|
|
v = v.reshape(O, -1, H, W)
|
|
|
|
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
|
|
|
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
|
|
|
|
# To resize pos embedding when using model at different size from pretrained weights
|
|
|
|
# To resize pos embedding when using model at different size from pretrained weights
|
|
|
|
v = resize_pos_embed(
|
|
|
|
v = resize_pos_embed(
|
|
|
|
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
|
|
|
v,
|
|
|
|
|
|
|
|
model.pos_embed,
|
|
|
|
|
|
|
|
getattr(model, 'num_prefix_tokens', 1),
|
|
|
|
|
|
|
|
model.patch_embed.grid_size
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
elif 'gamma_' in k:
|
|
|
|
|
|
|
|
# remap layer-scale gamma into sub-module (deit3 models)
|
|
|
|
|
|
|
|
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
|
|
|
elif 'pre_logits' in k:
|
|
|
|
elif 'pre_logits' in k:
|
|
|
|
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
|
|
|
|
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
|
|
|
|
continue
|
|
|
|
continue
|
|
|
|