|
|
|
@ -268,12 +268,16 @@ class HybridEmbed(nn.Module):
|
|
|
|
|
class VisionTransformer(nn.Module):
|
|
|
|
|
""" Vision Transformer
|
|
|
|
|
|
|
|
|
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` -
|
|
|
|
|
https://arxiv.org/abs/2010.11929
|
|
|
|
|
A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
|
|
|
|
|
- https://arxiv.org/abs/2010.11929
|
|
|
|
|
|
|
|
|
|
Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
|
|
|
|
|
- https://arxiv.org/abs/2012.12877
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
|
|
|
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None,
|
|
|
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None):
|
|
|
|
|
num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False,
|
|
|
|
|
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None,
|
|
|
|
|
weight_init=''):
|
|
|
|
|
"""
|
|
|
|
|
Args:
|
|
|
|
|
img_size (int, tuple): input image size
|
|
|
|
@ -287,11 +291,13 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
qkv_bias (bool): enable bias for qkv if True
|
|
|
|
|
qk_scale (float): override default qk scale of head_dim ** -0.5 if set
|
|
|
|
|
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
|
|
|
|
|
distilled (bool): model includes a distillation token and head as in DeiT models
|
|
|
|
|
drop_rate (float): dropout rate
|
|
|
|
|
attn_drop_rate (float): attention dropout rate
|
|
|
|
|
drop_path_rate (float): stochastic depth rate
|
|
|
|
|
hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module
|
|
|
|
|
norm_layer: (nn.Module): normalization layer
|
|
|
|
|
weight_init: (str): weight init scheme
|
|
|
|
|
"""
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
@ -307,11 +313,13 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
|
|
|
|
|
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
|
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
|
|
|
|
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None
|
|
|
|
|
num_tokens = 2 if distilled else 1
|
|
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim))
|
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
|
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
|
|
|
|
self.blocks = nn.ModuleList([
|
|
|
|
|
self.blocks = nn.Sequential(*[
|
|
|
|
|
Block(
|
|
|
|
|
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
|
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
|
|
|
@ -319,7 +327,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
self.norm = norm_layer(embed_dim)
|
|
|
|
|
|
|
|
|
|
# Representation layer
|
|
|
|
|
if representation_size:
|
|
|
|
|
if representation_size and not distilled:
|
|
|
|
|
self.num_features = representation_size
|
|
|
|
|
self.pre_logits = nn.Sequential(OrderedDict([
|
|
|
|
|
('fc', nn.Linear(embed_dim, representation_size)),
|
|
|
|
@ -328,11 +336,15 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
else:
|
|
|
|
|
self.pre_logits = nn.Identity()
|
|
|
|
|
|
|
|
|
|
# Classifier head
|
|
|
|
|
# Classifier head(s)
|
|
|
|
|
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
|
|
|
|
if num_classes > 0 and distilled else nn.Identity()
|
|
|
|
|
|
|
|
|
|
trunc_normal_(self.pos_embed, std=.02)
|
|
|
|
|
trunc_normal_(self.cls_token, std=.02)
|
|
|
|
|
if self.dist_token is not None:
|
|
|
|
|
trunc_normal_(self.dist_token, std=.02)
|
|
|
|
|
self.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def _init_weights(self, m):
|
|
|
|
@ -346,91 +358,58 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
|
def no_weight_decay(self):
|
|
|
|
|
return {'pos_embed', 'cls_token'}
|
|
|
|
|
return {'pos_embed', 'cls_token', 'dist_token'}
|
|
|
|
|
|
|
|
|
|
def get_classifier(self):
|
|
|
|
|
if self.dist_token is None:
|
|
|
|
|
return self.head
|
|
|
|
|
else:
|
|
|
|
|
return self.head, self.head_dist
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \
|
|
|
|
|
if num_classes > 0 and self.dist_token is not None else nn.Identity()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
B = x.shape[0]
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
|
|
|
|
|
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
|
|
|
|
x = torch.cat((cls_tokens, x), dim=1)
|
|
|
|
|
x = x + self.pos_embed
|
|
|
|
|
x = self.pos_drop(x)
|
|
|
|
|
|
|
|
|
|
for blk in self.blocks:
|
|
|
|
|
x = blk(x)
|
|
|
|
|
|
|
|
|
|
x = self.norm(x)[:, 0]
|
|
|
|
|
x = self.pre_logits(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DistilledVisionTransformer(VisionTransformer):
|
|
|
|
|
""" Vision Transformer with distillation token.
|
|
|
|
|
|
|
|
|
|
Paper: `Training data-efficient image transformers & distillation through attention` -
|
|
|
|
|
https://arxiv.org/abs/2012.12877
|
|
|
|
|
|
|
|
|
|
This impl of distilled ViT is taken from https://github.com/facebookresearch/deit
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
|
|
super().__init__(*args, **kwargs)
|
|
|
|
|
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
|
|
|
|
num_patches = self.patch_embed.num_patches
|
|
|
|
|
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
|
|
|
|
|
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
trunc_normal_(self.dist_token, std=.02)
|
|
|
|
|
trunc_normal_(self.pos_embed, std=.02)
|
|
|
|
|
self.head_dist.apply(self._init_weights)
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
B = x.shape[0]
|
|
|
|
|
x = self.patch_embed(x)
|
|
|
|
|
|
|
|
|
|
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
|
|
|
|
dist_token = self.dist_token.expand(B, -1, -1)
|
|
|
|
|
x = torch.cat((cls_tokens, dist_token, x), dim=1)
|
|
|
|
|
|
|
|
|
|
x = x + self.pos_embed
|
|
|
|
|
x = self.pos_drop(x)
|
|
|
|
|
|
|
|
|
|
for blk in self.blocks:
|
|
|
|
|
x = blk(x)
|
|
|
|
|
|
|
|
|
|
cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
|
|
|
|
if self.dist_token is None:
|
|
|
|
|
x = torch.cat((cls_token, x), dim=1)
|
|
|
|
|
else:
|
|
|
|
|
x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
|
|
|
|
|
x = self.pos_drop(x + self.pos_embed)
|
|
|
|
|
x = self.blocks(x)
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
if self.dist_token is not None:
|
|
|
|
|
return x[:, 0], x[:, 1]
|
|
|
|
|
else:
|
|
|
|
|
return self.pre_logits(x[:, 0])
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x, x_dist = self.forward_features(x)
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
x_dist = self.head_dist(x_dist)
|
|
|
|
|
if self.training:
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
if isinstance(x, tuple):
|
|
|
|
|
x, x_dist = self.head(x[0]), self.head_dist(x[1])
|
|
|
|
|
if self.training and not torch.jit.is_scripting():
|
|
|
|
|
# during inference, return the average of both classifier predictions
|
|
|
|
|
return x, x_dist
|
|
|
|
|
else:
|
|
|
|
|
# during inference, return the average of both classifier predictions
|
|
|
|
|
return (x + x_dist) / 2
|
|
|
|
|
else:
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def resize_pos_embed(posemb, posemb_new):
|
|
|
|
|
def resize_pos_embed(posemb, posemb_new, token='class'):
|
|
|
|
|
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
|
|
|
|
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
|
|
|
|
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
|
|
|
|
ntok_new = posemb_new.shape[1]
|
|
|
|
|
if True:
|
|
|
|
|
posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:]
|
|
|
|
|
if token:
|
|
|
|
|
assert token in ('class', 'distill')
|
|
|
|
|
token_idx = 2 if token == 'distill' else 1
|
|
|
|
|
posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:]
|
|
|
|
|
ntok_new -= 1
|
|
|
|
|
else:
|
|
|
|
|
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
|
|
|
@ -457,12 +436,12 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
v = v.reshape(O, -1, H, W)
|
|
|
|
|
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
|
|
|
|
# To resize pos embedding when using model at different size from pretrained weights
|
|
|
|
|
v = resize_pos_embed(v, model.pos_embed)
|
|
|
|
|
v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class')
|
|
|
|
|
out_dict[k] = v
|
|
|
|
|
return out_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs):
|
|
|
|
|
def _create_vision_transformer(variant, pretrained=False, **kwargs):
|
|
|
|
|
default_cfg = deepcopy(default_cfgs[variant])
|
|
|
|
|
overlay_external_default_cfg(default_cfg, kwargs)
|
|
|
|
|
default_num_classes = default_cfg['num_classes']
|
|
|
|
@ -480,9 +459,8 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Vision Transformer models.')
|
|
|
|
|
|
|
|
|
|
model_cls = DistilledVisionTransformer if distilled else VisionTransformer
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
model_cls, variant, pretrained,
|
|
|
|
|
VisionTransformer, variant, pretrained,
|
|
|
|
|
default_cfg=default_cfg,
|
|
|
|
|
img_size=img_size,
|
|
|
|
|
num_classes=num_classes,
|
|
|
|
|