From 5f81d4de234f579bdc988e8346da14b37a3af160 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 26 Jan 2022 22:53:57 -0800 Subject: [PATCH] Move DeiT to own file, vit getting crowded. Working towards fixing #1029, make pooling interface for transformers and mlp closer to convnets. Still working through some details... --- tests/test_models.py | 11 +- timm/models/__init__.py | 1 + timm/models/beit.py | 37 ++--- timm/models/cait.py | 44 +++--- timm/models/coat.py | 24 ++-- timm/models/convit.py | 3 +- timm/models/convmixer.py | 3 +- timm/models/convnext.py | 1 - timm/models/crossvit.py | 6 +- timm/models/deit.py | 201 ++++++++++++++++++++++++++ timm/models/levit.py | 10 +- timm/models/mlp_mixer.py | 2 +- timm/models/mobilenetv3.py | 7 +- timm/models/pit.py | 11 +- timm/models/swin_transformer.py | 57 ++++---- timm/models/tnt.py | 3 +- timm/models/twins.py | 3 +- timm/models/vision_transformer.py | 231 ++++++++---------------------- timm/models/xcit.py | 6 +- 19 files changed, 370 insertions(+), 291 deletions(-) create mode 100644 timm/models/deit.py diff --git a/tests/test_models.py b/tests/test_models.py index a6fc4a4a..6b448dc9 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -205,22 +205,23 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = model.forward_features(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - assert outputs.shape[1] == model.num_features + feat_dim = -1 if outputs.ndim == 3 else 1 + assert outputs.shape[feat_dim] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features model.reset_classifier(0) outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - assert len(outputs.shape) == 2 - assert outputs.shape[1] == model.num_features + feat_dim = -1 if outputs.ndim == 3 else 1 + assert outputs.shape[feat_dim] == model.num_features model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - assert len(outputs.shape) == 2 - assert outputs.shape[1] == model.num_features + feat_dim = -1 if outputs.ndim == 3 else 1 + assert outputs.shape[feat_dim] == model.num_features # check classifier name matches default_cfg if cfg.get('num_classes', None): diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 306d5aeb..44e31f36 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -8,6 +8,7 @@ from .convmixer import * from .convnext import * from .crossvit import * from .cspnet import * +from .deit import * from .densenet import * from .dla import * from .dpn import * diff --git a/timm/models/beit.py b/timm/models/beit.py index e82f6f63..68ca44b7 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -232,13 +232,15 @@ class Beit(nn.Module): """ Vision Transformer with support for patch or hybrid CNN input stage """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), init_values=None, - use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, - use_mean_pooling=True, init_scale=0.001): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, 
drop_rate=0., + attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + init_values=None, use_abs_pos_emb=True, use_rel_pos_bias=False, use_shared_rel_pos_bias=False, + head_init_scale=0.001): super().__init__() self.num_classes = num_classes + self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.patch_embed = PatchEmbed( @@ -247,10 +249,7 @@ class Beit(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - if use_abs_pos_emb: - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - else: - self.pos_embed = None + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) if use_abs_pos_emb else None self.pos_drop = nn.Dropout(p=drop_rate) if use_shared_rel_pos_bias: @@ -266,8 +265,9 @@ class Beit(nn.Module): drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, init_values=init_values, window_size=self.patch_embed.grid_size if use_rel_pos_bias else None) for i in range(depth)]) - self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim) - self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None + use_fc_norm = self.global_pool == 'avg' + self.norm = nn.Identity() if use_fc_norm else norm_layer(embed_dim) + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else None self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() self.apply(self._init_weights) @@ -278,8 +278,8 @@ class Beit(nn.Module): self.fix_init_weight() if isinstance(self.head, nn.Linear): trunc_normal_(self.head.weight, std=.02) - self.head.weight.data.mul_(init_scale) - self.head.bias.data.mul_(init_scale) + self.head.weight.data.mul_(head_init_scale) + self.head.bias.data.mul_(head_init_scale) def fix_init_weight(self): def rescale(param, layer_id): @@ -327,14 +327,15 @@ class Beit(nn.Module): x = blk(x, rel_pos_bias=rel_pos_bias) x = self.norm(x) - if self.fc_norm is not None: - t = x[:, 1:, :] - return self.fc_norm(t.mean(1)) - else: - return x[:, 0] + return x def forward(self, x): x = self.forward_features(x) + if self.fc_norm is not None: + x = x[:, 1:].mean(dim=1) + x = self.fc_norm(x) + else: + x = x[:, 0] x = self.head(x) return x diff --git a/timm/models/cait.py b/timm/models/cait.py index c09f942c..331111f2 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -213,11 +213,11 @@ class Cait(nn.Module): act_layer=nn.GELU, attn_block=TalkingHeadAttn, mlp_block=Mlp, - init_scale=1e-4, + init_values=1e-4, attn_block_token_only=ClassAttn, mlp_block_token_only=Mlp, depth_token_only=2, - mlp_ratio_clstk=4.0 + mlp_ratio_token_only=4.0 ): super().__init__() @@ -234,19 +234,19 @@ class Cait(nn.Module): self.pos_drop = nn.Dropout(p=drop_rate) dpr = [drop_path_rate for i in range(depth)] - self.blocks = nn.ModuleList([ + self.blocks = nn.Sequential(*[ block_layers( dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, - act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_scale) + act_layer=act_layer, attn_block=attn_block, mlp_block=mlp_block, init_values=init_values) for i in range(depth)]) self.blocks_token_only = nn.ModuleList([ block_layers_token( - dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio_clstk, qkv_bias=qkv_bias, + dim=embed_dim, num_heads=num_heads, 
mlp_ratio=mlp_ratio_token_only, qkv_bias=qkv_bias, drop=0.0, attn_drop=0.0, drop_path=0.0, norm_layer=norm_layer, act_layer=act_layer, attn_block=attn_block_token_only, - mlp_block=mlp_block_token_only, init_values=init_scale) + mlp_block=mlp_block_token_only, init_values=init_values) for i in range(depth_token_only)]) self.norm = norm_layer(embed_dim) @@ -281,25 +281,21 @@ class Cait(nn.Module): def forward_features(self, x): B = x.shape[0] x = self.patch_embed(x) - - cls_tokens = self.cls_token.expand(B, -1, -1) - x = x + self.pos_embed x = self.pos_drop(x) + x = self.blocks(x) - for i, blk in enumerate(self.blocks): - x = blk(x) - + cls_tokens = self.cls_token.expand(B, -1, -1) for i, blk in enumerate(self.blocks_token_only): cls_tokens = blk(x, cls_tokens) - x = torch.cat((cls_tokens, x), dim=1) x = self.norm(x) - return x[:, 0] + return x def forward(self, x): x = self.forward_features(x) + x = x[:, 0] x = self.head(x) return x @@ -326,69 +322,69 @@ def _create_cait(variant, pretrained=False, **kwargs): @register_model def cait_xxs24_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs) model = _create_cait('cait_xxs24_224', pretrained=pretrained, **model_args) return model @register_model def cait_xxs24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=192, depth=24, num_heads=4, init_values=1e-5, **kwargs) model = _create_cait('cait_xxs24_384', pretrained=pretrained, **model_args) return model @register_model def cait_xxs36_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs) model = _create_cait('cait_xxs36_224', pretrained=pretrained, **model_args) return model @register_model def cait_xxs36_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=192, depth=36, num_heads=4, init_values=1e-5, **kwargs) model = _create_cait('cait_xxs36_384', pretrained=pretrained, **model_args) return model @register_model def cait_xs24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=288, depth=24, num_heads=6, init_values=1e-5, **kwargs) model = _create_cait('cait_xs24_384', pretrained=pretrained, **model_args) return model @register_model def cait_s24_224(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs) model = _create_cait('cait_s24_224', pretrained=pretrained, **model_args) return model @register_model def cait_s24_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_scale=1e-5, **kwargs) + model_args = dict(patch_size=16, embed_dim=384, depth=24, num_heads=8, init_values=1e-5, **kwargs) model = _create_cait('cait_s24_384', pretrained=pretrained, **model_args) return model @register_model def cait_s36_384(pretrained=False, **kwargs): 
- model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_scale=1e-6, **kwargs) + model_args = dict(patch_size=16, embed_dim=384, depth=36, num_heads=8, init_values=1e-6, **kwargs) model = _create_cait('cait_s36_384', pretrained=pretrained, **model_args) return model @register_model def cait_m36_384(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_scale=1e-6, **kwargs) + model_args = dict(patch_size=16, embed_dim=768, depth=36, num_heads=16, init_values=1e-6, **kwargs) model = _create_cait('cait_m36_384', pretrained=pretrained, **model_args) return model @register_model def cait_m48_448(pretrained=False, **kwargs): - model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_scale=1e-6, **kwargs) + model_args = dict(patch_size=16, embed_dim=768, depth=48, num_heads=16, init_values=1e-6, **kwargs) model = _create_cait('cait_m48_448', pretrained=pretrained, **model_args) return model diff --git a/timm/models/coat.py b/timm/models/coat.py index 6425d67e..4188243f 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -447,6 +447,7 @@ class CoaT(nn.Module): self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() else: # CoaT-Lite series: Use feature of last scale for classification. + self.aggregate = None self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() # Initialize weights. @@ -542,8 +543,7 @@ class CoaT(nn.Module): else: # Return features for classification. x4 = self.norm4(x4) - x4_cls = x4[:, 0] - return x4_cls + return x4 # Parallel blocks. for blk in self.parallel_blocks: @@ -574,20 +574,20 @@ class CoaT(nn.Module): x2 = self.norm2(x2) x3 = self.norm3(x3) x4 = self.norm4(x4) - x2_cls = x2[:, :1] # [B, 1, C] - x3_cls = x3[:, :1] - x4_cls = x4[:, :1] - merged_cls = torch.cat((x2_cls, x3_cls, x4_cls), dim=1) # [B, 3, C] - merged_cls = self.aggregate(merged_cls).squeeze(dim=1) # Shape: [B, C] - return merged_cls - - def forward(self, x): - if self.return_interm_layers: + return [x2, x3, x4] + + def forward(self, x) -> torch.Tensor: + if not torch.jit.is_scripting() and self.return_interm_layers: # Return intermediate features (for down-stream tasks). return self.forward_features(x) else: # Return features for classification. 
- x = self.forward_features(x) + x_feat = self.forward_features(x) + if isinstance(x_feat, (tuple, list)): + x = torch.cat([xl[:, :1] for xl in x_feat], dim=1) # [B, 3, C] + x = self.aggregate(x).squeeze(dim=1) # Shape: [B, C] + else: + x = x_feat[:, 0] x = self.head(x) return x diff --git a/timm/models/convit.py b/timm/models/convit.py index 51165aef..a3287574 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -308,10 +308,11 @@ class ConViT(nn.Module): x = blk(x) x = self.norm(x) - return x[:, 0] + return x def forward(self, x): x = self.forward_features(x) + x = x[:, 0] x = self.head(x) return x diff --git a/timm/models/convmixer.py b/timm/models/convmixer.py index df551788..f4eb9795 100644 --- a/timm/models/convmixer.py +++ b/timm/models/convmixer.py @@ -69,13 +69,12 @@ class ConvMixer(nn.Module): def forward_features(self, x): x = self.stem(x) x = self.blocks(x) - x = self.pooling(x) return x def forward(self, x): x = self.forward_features(x) + x = self.pooling(x) x = self.head(x) - return x diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 5f75647b..8f0b9e0a 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -319,7 +319,6 @@ def checkpoint_filter_fn(state_dict, model): def _create_convnext(variant, pretrained=False, **kwargs): model = build_model_with_cfg( ConvNeXt, variant, pretrained, - default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), **kwargs) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index f533a86c..653da40b 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -368,7 +368,7 @@ class CrossViT(nn.Module): [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) - def forward_features(self, x): + def forward_features(self, x) -> List[torch.Tensor]: B = x.shape[0] xs = [] for i, patch_embed in enumerate(self.patch_embed): @@ -389,11 +389,11 @@ class CrossViT(nn.Module): # NOTE: was before branch token section, move to here to assure all branch token are before layer norm xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] - return [xo[:, 0] for xo in xs] + return xs def forward(self, x): xs = self.forward_features(x) - ce_logits = [head(xs[i]) for i, head in enumerate(self.head)] + ce_logits = [head(xs[i][:, 0]) for i, head in enumerate(self.head)] if not isinstance(self.head[0], nn.Identity): ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) return ce_logits diff --git a/timm/models/deit.py b/timm/models/deit.py new file mode 100644 index 00000000..5cb49394 --- /dev/null +++ b/timm/models/deit.py @@ -0,0 +1,201 @@ +""" DeiT - Data-efficient Image Transformers + +DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below +paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 + +Modifications copyright 2021, Ross Wightman +""" +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. 
+import torch +from torch import nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn + +from .helpers import build_model_with_cfg +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + # deit models (FB weights) + 'deit_tiny_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth'), + 'deit_small_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth'), + 'deit_base_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth'), + 'deit_base_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', + input_size=(3, 384, 384), crop_pct=1.0), + + 'deit_tiny_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', + classifier=('head', 'head_dist')), + 'deit_small_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', + classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_224': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', + classifier=('head', 'head_dist')), + 'deit_base_distilled_patch16_384': _cfg( + url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', + input_size=(3, 384, 384), crop_pct=1.0, + classifier=('head', 'head_dist')), +} + + +class VisionTransformerDistilled(VisionTransformer): + """ Vision Transformer w/ Distillation Token and Head + + Distillation token & head support for `DeiT: Data-efficient Image Transformers` + - https://arxiv.org/abs/2012.12877 + """ + + def __init__(self, *args, **kwargs): + weight_init = kwargs.pop('weight_init', '') + super().__init__(*args, **kwargs, weight_init='skip') + self.num_tokens = 2 + self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim)) + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity() + + self.init_weights(weight_init) + + def init_weights(self, mode=''): + trunc_normal_(self.dist_token, std=.02) + super().init_weights(mode=mode) + + def get_classifier(self): + return self.head, self.head_dist + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x) -> torch.Tensor: + x = self.patch_embed(x) + x = torch.cat(( + self.cls_token.expand(x.shape[0], -1, -1), + self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = self.pos_drop(x + self.pos_embed) + x = self.blocks(x) + x = self.norm(x) + return x + + def forward(self, x): + x = self.forward_features(x) + x_dist = self.head_dist(x[:, 1]) + x = self.head(x[:, 0]) + if self.training and not 
torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + + +def _create_deit(variant, pretrained=False, distilled=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model_cls = VisionTransformerDistilled if distilled else VisionTransformer + model = build_model_with_cfg( + model_cls, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def deit_tiny_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_deit('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_small_patch16_224(pretrained=False, **kwargs): + """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_deit('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_224(pretrained=False, **kwargs): + """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_deit('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_base_patch16_384(pretrained=False, **kwargs): + """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_deit('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) + model = _create_deit( + 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_small_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) + model = _create_deit( + 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_224(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. 
+ """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_deit( + 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) + return model + + +@register_model +def deit_base_distilled_patch16_384(pretrained=False, **kwargs): + """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). + ImageNet-1k weights from https://github.com/facebookresearch/deit. + """ + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model = _create_deit( + 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) + return model diff --git a/timm/models/levit.py b/timm/models/levit.py index 23f4df31..5c21f50f 100644 --- a/timm/models/levit.py +++ b/timm/models/levit.py @@ -290,10 +290,10 @@ class Attention(nn.Module): qkv = self.qkv(x) q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) q = q.permute(0, 2, 1, 3) - k = k.permute(0, 2, 1, 3) + k = k.permute(0, 2, 3, 1) v = v.permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = q @ k * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) @@ -383,11 +383,11 @@ class AttentionSubsample(nn.Module): else: B, N, C = x.shape k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) - k = k.permute(0, 2, 1, 3) # BHNC + k = k.permute(0, 2, 3, 1) # BHCN v = v.permute(0, 2, 1, 3) # BHNC q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) - attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = q @ k * self.scale + self.get_attention_biases(x.device) attn = attn.softmax(dim=-1) x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) @@ -519,11 +519,11 @@ class Levit(nn.Module): if not self.use_conv: x = x.flatten(2).transpose(1, 2) x = self.blocks(x) - x = x.mean((-2, -1)) if self.use_conv else x.mean(1) return x def forward(self, x): x = self.forward_features(x) + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) if self.head_dist is not None: x, x_dist = self.head(x), self.head_dist(x) if self.training and not torch.jit.is_scripting(): diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index dc5d70a4..ca20fbc4 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -294,11 +294,11 @@ class MlpMixer(nn.Module): x = self.stem(x) x = self.blocks(x) x = self.norm(x) - x = x.mean(dim=1) return x def forward(self, x): x = self.forward_features(x) + x = x.mean(dim=1) x = self.head(x) return x diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index e92224ab..7171e2ee 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -200,9 +200,10 @@ class MobileNetV3Features(nn.Module): and object detection models. 
""" - def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, - stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, - se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): + def __init__( + self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3, + stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels, + se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.): super(MobileNetV3Features, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d diff --git a/timm/models/pit.py b/timm/models/pit.py index 843880e7..b0788c1e 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -125,10 +125,8 @@ class ConvHeadPooling(nn.Module): self.fc = nn.Linear(in_feature, out_feature) def forward(self, x, cls_token) -> Tuple[torch.Tensor, torch.Tensor]: - x = self.conv(x) cls_token = self.fc(cls_token) - return x, cls_token @@ -225,21 +223,18 @@ class PoolingVisionTransformer(nn.Module): cls_tokens = self.cls_token.expand(x.shape[0], -1, -1) x, cls_tokens = self.transformers((x, cls_tokens)) cls_tokens = self.norm(cls_tokens) - if self.head_dist is not None: - return cls_tokens[:, 0], cls_tokens[:, 1] - else: - return cls_tokens[:, 0] + return cls_tokens def forward(self, x): x = self.forward_features(x) if self.head_dist is not None: - x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple + x, x_dist = self.head(x[:, 0]), self.head_dist(x[:, 1]) # x must be a tuple if self.training and not torch.jit.is_scripting(): return x, x_dist else: return (x + x_dist) / 2 else: - return self.head(x) + return self.head(x[:, 0]) def checkpoint_filter_fn(state_dict, model): diff --git a/timm/models/swin_transformer.py b/timm/models/swin_transformer.py index c3151a74..cd571a0d 100644 --- a/timm/models/swin_transformer.py +++ b/timm/models/swin_transformer.py @@ -14,7 +14,7 @@ Modifications and additions for timm hacked together by / Copyright 2021, Ross W # -------------------------------------------------------- import logging import math -from copy import deepcopy +from functools import partial from typing import Optional import torch @@ -23,9 +23,8 @@ import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_function -from .helpers import build_model_with_cfg -from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from .layers import _assert +from .helpers import build_model_with_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_, _assert from .registry import register_model from .vision_transformer import checkpoint_filter_fn, _init_vit_weights @@ -444,15 +443,17 @@ class SwinTransformer(nn.Module): use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False """ - def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, - embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), - window_size=7, mlp_ratio=4., qkv_bias=True, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, - norm_layer=nn.LayerNorm, ape=False, patch_norm=True, - use_checkpoint=False, weight_init='', **kwargs): + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + use_checkpoint=False, weight_init='', **kwargs): super().__init__() - + assert global_pool in ('', 'avg') self.num_classes = num_classes + self.global_pool = global_pool self.num_layers = len(depths) self.embed_dim = embed_dim self.ape = ape @@ -468,18 +469,11 @@ class SwinTransformer(nn.Module): self.patch_grid = self.patch_embed.grid_size # absolute position embedding - if self.ape: - self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) - trunc_normal_(self.absolute_pos_embed, std=.02) - else: - self.absolute_pos_embed = None - + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) if ape else None self.pos_drop = nn.Dropout(p=drop_rate) - # stochastic depth - dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule - # build layers + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule layers = [] for i_layer in range(self.num_layers): layers += [BasicLayer( @@ -500,16 +494,16 @@ class SwinTransformer(nn.Module): self.layers = nn.Sequential(*layers) self.norm = norm_layer(self.num_features) - self.avgpool = nn.AdaptiveAvgPool1d(1) self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - assert weight_init in ('jax', 'jax_nlhb', 'nlhb', '') - head_bias = -math.log(self.num_classes) if 'nlhb' in weight_init else 0. - if weight_init.startswith('jax'): - for n, m in self.named_modules(): - _init_vit_weights(m, n, head_bias=head_bias, jax_impl=True) - else: - self.apply(_init_vit_weights) + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'jax_nlhb', 'nlhb', '') + if self.absolute_pos_embed is not None: + trunc_normal_(self.absolute_pos_embed, std=.02) + head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
+ named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self) @torch.jit.ignore def no_weight_decay(self): @@ -522,8 +516,9 @@ class SwinTransformer(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool='avg'): self.num_classes = num_classes + self.global_pool = global_pool self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): @@ -533,12 +528,12 @@ class SwinTransformer(nn.Module): x = self.pos_drop(x) x = self.layers(x) x = self.norm(x) # B L C - x = self.avgpool(x.transpose(1, 2)) # B C 1 - x = torch.flatten(x, 1) return x def forward(self, x): x = self.forward_features(x) + if self.global_pool == 'avg': + x = x.mean(dim=1) x = self.head(x) return x diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 8affc13e..60879ccd 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -226,10 +226,11 @@ class TNT(nn.Module): pixel_embed, patch_embed = blk(pixel_embed, patch_embed) patch_embed = self.norm(patch_embed) - return patch_embed[:, 0] + return patch_embed def forward(self, x): x = self.forward_features(x) + x = x[:, 0] x = self.head(x) return x diff --git a/timm/models/twins.py b/timm/models/twins.py index 6894e5c2..bb82e1fc 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -357,10 +357,11 @@ class Twins(nn.Module): if i < len(self.depths) - 1: x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() x = self.norm(x) - return x.mean(dim=1) # GAP here + return x def forward(self, x): x = self.forward_features(x) + x = x.mean(dim=1) x = self.head(x) return x diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1bfe30cb..6d89f2bf 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -10,9 +10,6 @@ A PyTorch implement of Vision Transformers as described in: The official jax code is released and available at https://github.com/google-research/vision_transformer -DeiT model defs and weights from https://github.com/facebookresearch/deit, -paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877 - Acknowledgments: * The paper authors for releasing code and weights, thanks! * I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... 
check it out @@ -26,7 +23,6 @@ import math import logging from functools import partial from collections import OrderedDict -from copy import deepcopy import torch import torch.nn as nn @@ -105,6 +101,7 @@ default_cfgs = { 'L_16-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_384.npz', input_size=(3, 384, 384), crop_pct=1.0), + 'vit_large_patch14_224': _cfg(url=''), 'vit_huge_patch14_224': _cfg(url=''), 'vit_giant_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''), @@ -161,32 +158,6 @@ default_cfgs = { url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - # deit models (FB weights) - 'deit_tiny_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_small_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_base_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), - 'deit_base_patch16_384': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0), - 'deit_tiny_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_small_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_base_distilled_patch16_224': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, classifier=('head', 'head_dist')), - 'deit_base_distilled_patch16_384': _cfg( - url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, input_size=(3, 384, 384), crop_pct=1.0, - classifier=('head', 'head_dist')), # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil_in21k': _cfg( @@ -253,15 +224,13 @@ class VisionTransformer(nn.Module): A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - https://arxiv.org/abs/2010.11929 - - Includes distillation token & head support for `DeiT: Data-efficient Image Transformers` - - https://arxiv.org/abs/2012.12877 """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, - num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, distilled=False, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None, - act_layer=None, weight_init=''): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, + num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, global_pool='', + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', + embed_layer=PatchEmbed, norm_layer=None, act_layer=None): """ Args: img_size (int, tuple): input image size @@ -274,18 +243,19 @@ class 
VisionTransformer(nn.Module): mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set - distilled (bool): model includes a distillation token and head as in DeiT models + weight_init: (str): weight init scheme drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer - weight_init: (str): weight init scheme + act_layer: (nn.Module): MLP activation layer """ super().__init__() self.num_classes = num_classes + self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 2 if distilled else 1 + self.num_tokens = 1 norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU @@ -294,7 +264,6 @@ class VisionTransformer(nn.Module): num_patches = self.patch_embed.num_patches self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) self.pos_drop = nn.Dropout(p=drop_rate) @@ -304,38 +273,41 @@ class VisionTransformer(nn.Module): dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) - self.norm = norm_layer(embed_dim) - - # Representation layer - if representation_size and not distilled: - self.num_features = representation_size + use_fc_norm = self.global_pool == 'avg' + self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() + + # Representation layer. Used for original ViT models w/ in21k pretraining. + self.representation_size = representation_size + self.pre_logits = nn.Identity() + if representation_size: + self._reset_representation(representation_size) + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() + final_chs = self.representation_size if self.representation_size else self.embed_dim + self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def _reset_representation(self, representation_size): + self.representation_size = representation_size + if self.representation_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(embed_dim, representation_size)), + ('fc', nn.Linear(self.embed_dim, self.representation_size)), ('act', nn.Tanh()) ])) else: self.pre_logits = nn.Identity() - # Classifier head(s) - self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - self.head_dist = None - if distilled: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() - - self.init_weights(weight_init) - def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'nlhb', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. 
trunc_normal_(self.pos_embed, std=.02) - if self.dist_token is not None: - trunc_normal_(self.dist_token, std=.02) - if mode.startswith('jax'): - # leave cls token as zeros to match jax impl - named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self) - else: + if 'jax' not in mode: + # init cls token to truncated normal if not following jax impl, jax impl is zero trunc_normal_(self.cls_token, std=.02) - self.apply(_init_vit_weights) + named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl='jax' in mode), self) def _init_weights(self, m): # this fn left here for compat with downstream users @@ -350,43 +322,33 @@ class VisionTransformer(nn.Module): return {'pos_embed', 'cls_token', 'dist_token'} def get_classifier(self): - if self.dist_token is None: - return self.head - else: - return self.head, self.head_dist + return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool='', representation_size=None): self.num_classes = num_classes - self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() - if self.num_tokens == 2: - self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity() + self.global_pool = global_pool + if representation_size is not None: + self._reset_representation(representation_size) + final_chs = self.representation_size if self.representation_size else self.embed_dim + self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) - cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks - if self.dist_token is None: - x = torch.cat((cls_token, x), dim=1) - else: - x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1) + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) x = self.blocks(x) x = self.norm(x) - if self.dist_token is None: - return self.pre_logits(x[:, 0]) - else: - return x[:, 0], x[:, 1] + return x def forward(self, x): x = self.forward_features(x) - if self.head_dist is not None: - x, x_dist = self.head(x[0]), self.head_dist(x[1]) # x must be a tuple - if self.training and not torch.jit.is_scripting(): - # during inference, return the average of both classifier predictions - return x, x_dist - else: - return (x + x_dist) / 2 + if self.global_pool == 'avg': + x = x[:, self.num_tokens:].mean(dim=1) else: - x = self.head(x) + x = x[:, 0] + x = self.fc_norm(x) + x = self.pre_logits(x) + x = self.head(x) return x @@ -708,7 +670,7 @@ def vit_large_patch32_384(pretrained=False, **kwargs): @register_model def vit_large_patch16_224(pretrained=False, **kwargs): - """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). + """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. 
""" model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) @@ -726,6 +688,15 @@ def vit_large_patch16_384(pretrained=False, **kwargs): return model +@register_model +def vit_large_patch14_224(pretrained=False, **kwargs): + """ ViT-Large model (ViT-L/14) + """ + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). @@ -914,90 +885,6 @@ def vit_base_patch8_224_dino(pretrained=False, **kwargs): return model -@register_model -def deit_tiny_patch16_224(pretrained=False, **kwargs): - """ DeiT-tiny model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('deit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def deit_small_patch16_224(pretrained=False, **kwargs): - """ DeiT-small model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('deit_small_patch16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def deit_base_patch16_224(pretrained=False, **kwargs): - """ DeiT base model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('deit_base_patch16_224', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def deit_base_patch16_384(pretrained=False, **kwargs): - """ DeiT base model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('deit_base_patch16_384', pretrained=pretrained, **model_kwargs) - return model - - -@register_model -def deit_tiny_distilled_patch16_224(pretrained=False, **kwargs): - """ DeiT-tiny distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer( - 'deit_tiny_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) - return model - - -@register_model -def deit_small_distilled_patch16_224(pretrained=False, **kwargs): - """ DeiT-small distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. 
- """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer( - 'deit_small_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) - return model - - -@register_model -def deit_base_distilled_patch16_224(pretrained=False, **kwargs): - """ DeiT-base distilled model @ 224x224 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer( - 'deit_base_distilled_patch16_224', pretrained=pretrained, distilled=True, **model_kwargs) - return model - - -@register_model -def deit_base_distilled_patch16_384(pretrained=False, **kwargs): - """ DeiT-base distilled model @ 384x384 from paper (https://arxiv.org/abs/2012.12877). - ImageNet-1k weights from https://github.com/facebookresearch/deit. - """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer( - 'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs) - return model - - @register_model def vit_base_patch16_224_miil_in21k(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). diff --git a/timm/models/xcit.py b/timm/models/xcit.py index 55cb27c1..91c99fc5 100644 --- a/timm/models/xcit.py +++ b/timm/models/xcit.py @@ -426,17 +426,17 @@ class XCiT(nn.Module): for blk in self.blocks: x = blk(x, Hp, Wp) - cls_tokens = self.cls_token.expand(B, -1, -1) - x = torch.cat((cls_tokens, x), dim=1) + x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1) for blk in self.cls_attn_blocks: x = blk(x) - x = self.norm(x)[:, 0] + x = self.norm(x) return x def forward(self, x): x = self.forward_features(x) + x = x[:, 0] x = self.head(x) return x