From f66e5f0e35c83f086f5fe788af7d6ef5f1804001 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sun, 28 Aug 2022 15:24:04 -0700
Subject: [PATCH] Fix class token support in MViT-V2, add small_class variant
 to ensure it's tested. Fix #1443

---
 timm/models/mvitv2.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py
index 002225c6..76e70135 100644
--- a/timm/models/mvitv2.py
+++ b/timm/models/mvitv2.py
@@ -135,6 +135,11 @@ model_cfgs = dict(
         num_heads=2,
         expand_attn=False,
     ),
+
+    mvitv2_small_cls=MultiScaleVitCfg(
+        depths=(1, 2, 11, 2),
+        use_cls_token=True,
+    ),
 )
 
 
@@ -641,7 +646,7 @@ class MultiScaleBlock(nn.Module):
         if self.shortcut_pool_attn is None:
             return x
         if self.has_cls_token:
-            cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
+            cls_tok, x = x[:, :1, :], x[:, 1:, :]
         else:
             cls_tok = None
         B, L, C = x.shape
@@ -650,7 +655,7 @@ class MultiScaleBlock(nn.Module):
         x = self.shortcut_pool_attn(x)
         x = x.reshape(B, C, -1).transpose(1, 2)
         if cls_tok is not None:
-            x = torch.cat((cls_tok, x), dim=2)
+            x = torch.cat((cls_tok, x), dim=1)
         return x
 
     def forward(self, x, feat_size: List[int]):
@@ -996,3 +1001,8 @@ def mvitv2_large(pretrained=False, **kwargs):
 # @register_model
 # def mvitv2_huge_in21k(pretrained=False, **kwargs):
 #     return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def mvitv2_small_cls(pretrained=False, **kwargs):
+    return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
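
A minimal usage sketch (not part of the patch) exercising the new class-token variant registered above; it assumes a timm install that includes this change and uses the default 224x224 input size:

    import torch
    import timm

    # Instantiate the new small variant with use_cls_token=True added by this patch.
    # pretrained=False since no weights are published for this config in the patch.
    model = timm.create_model('mvitv2_small_cls', pretrained=False, num_classes=10)
    model.eval()

    # A forward pass exercises the fixed class-token slicing/concat in MultiScaleBlock.
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 10])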