diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 82d4ee49..c7c9027d 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -268,12 +268,16 @@ class HybridEmbed(nn.Module): class VisionTransformer(nn.Module): """ Vision Transformer - A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - - https://arxiv.org/abs/2010.11929 + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` + - https://arxiv.org/abs/2010.11929 + + Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None): + num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, distilled=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., hybrid_backbone=None, norm_layer=None, + weight_init=''): """ Args: img_size (int, tuple): input image size @@ -287,11 +291,13 @@ class VisionTransformer(nn.Module): qkv_bias (bool): enable bias for qkv if True qk_scale (float): override default qk scale of head_dim ** -0.5 if set representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + distilled (bool): model includes a distillation token and head as in DeiT models drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate hybrid_backbone (nn.Module): CNN backbone to use in-place of PatchEmbed module norm_layer: (nn.Module): normalization layer + weight_init: (str): weight init scheme """ super().__init__() self.num_classes = num_classes @@ -307,11 +313,13 @@ class VisionTransformer(nn.Module): num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) if distilled else None + num_tokens = 2 if distilled else 1 + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule - self.blocks = nn.ModuleList([ + self.blocks = nn.Sequential(*[ Block( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) @@ -319,7 +327,7 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) # Representation layer - if representation_size: + if representation_size and not distilled: self.num_features = representation_size self.pre_logits = nn.Sequential(OrderedDict([ ('fc', nn.Linear(embed_dim, representation_size)), @@ -328,11 +336,15 @@ class VisionTransformer(nn.Module): else: self.pre_logits = nn.Identity() - # Classifier head + # Classifier head(s) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and distilled else nn.Identity() trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.cls_token, std=.02) + if self.dist_token is not None: + trunc_normal_(self.dist_token, std=.02) self.apply(self._init_weights) def _init_weights(self, m): @@ -346,91 +358,58 @@ class VisionTransformer(nn.Module): @torch.jit.ignore def no_weight_decay(self): - return {'pos_embed', 'cls_token'} + return {'pos_embed', 'cls_token', 'dist_token'} def get_classifier(self): - return self.head + if self.dist_token is None: + return self.head + else: + return self.head, self.head_dist def reset_classifier(self, num_classes, global_pool=''): self.num_classes = num_classes self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) \ + if num_classes > 0 and self.dist_token is not None else nn.Identity() def forward_features(self, x): - B = x.shape[0] - x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - x = torch.cat((cls_tokens, x), dim=1) - x = x + self.pos_embed - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - - x = self.norm(x)[:, 0] - x = self.pre_logits(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - -class DistilledVisionTransformer(VisionTransformer): - """ Vision Transformer with distillation token. - - Paper: `Training data-efficient image transformers & distillation through attention` - - https://arxiv.org/abs/2012.12877 - - This impl of distilled ViT is taken from https://github.com/facebookresearch/deit - """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) - num_patches = self.patch_embed.num_patches - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim)) - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() - - trunc_normal_(self.dist_token, std=.02) - trunc_normal_(self.pos_embed, std=.02) - self.head_dist.apply(self._init_weights) - - def forward_features(self, x): - B = x.shape[0] x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks - dist_token = self.dist_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, dist_token, x), dim=1) - - x = x + self.pos_embed - x = self.pos_drop(x) - - for blk in self.blocks: - x = blk(x) - + cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks + if self.dist_token is None: + x = torch.cat((cls_token, x), dim=1) + else: + x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) x = self.norm(x) - return x[:, 0], x[:, 1] + if self.dist_token is not None: + return x[:, 0], x[:, 1] + else: + return self.pre_logits(x[:, 0]) def forward(self, x): - x, x_dist = self.forward_features(x) - x = self.head(x) - x_dist = self.head_dist(x_dist) - if self.training: - return x, x_dist + x = self.forward_features(x) + if isinstance(x, tuple): + x, x_dist = self.head(x[0]), self.head_dist(x[1]) + if self.training and not torch.jit.is_scripting(): + # during inference, return the average of both classifier predictions + return x, x_dist + else: + return (x + x_dist) / 2 else: - # during inference, return the average of both classifier predictions - return (x + x_dist) / 2 + x = self.head(x) + return x -def resize_pos_embed(posemb, posemb_new): +def resize_pos_embed(posemb, posemb_new, token='class'): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) ntok_new = posemb_new.shape[1] - if True: - posemb_tok, posemb_grid = posemb[:, :1], posemb[0, 1:] + if token: + assert token in ('class', 'distill') + token_idx = 2 if token == 'distill' else 1 + posemb_tok, posemb_grid = posemb[:, :token_idx], posemb[0, token_idx:] ntok_new -= 1 else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] @@ -457,12 +436,12 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed) + v = resize_pos_embed(v, model.pos_embed, token='distill' if model.dist_token is not None else 'class') out_dict[k] = v return out_dict -def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwargs): +def _create_vision_transformer(variant, pretrained=False, **kwargs): default_cfg = deepcopy(default_cfgs[variant]) overlay_external_default_cfg(default_cfg, kwargs) default_num_classes = default_cfg['num_classes'] @@ -480,9 +459,8 @@ def _create_vision_transformer(variant, pretrained=False, distilled=False, **kwa if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - model_cls = DistilledVisionTransformer if distilled else VisionTransformer model = build_model_with_cfg( - model_cls, variant, pretrained, + VisionTransformer, variant, pretrained, default_cfg=default_cfg, img_size=img_size, num_classes=num_classes,