diff --git a/timm/models/mvitv2.py b/timm/models/mvitv2.py index 002225c6..76e70135 100644 --- a/timm/models/mvitv2.py +++ b/timm/models/mvitv2.py @@ -135,6 +135,11 @@ model_cfgs = dict( num_heads=2, expand_attn=False, ), + + mvitv2_small_cls=MultiScaleVitCfg( + depths=(1, 2, 11, 2), + use_cls_token=True, + ), ) @@ -641,7 +646,7 @@ class MultiScaleBlock(nn.Module): if self.shortcut_pool_attn is None: return x if self.has_cls_token: - cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :] + cls_tok, x = x[:, :1, :], x[:, 1:, :] else: cls_tok = None B, L, C = x.shape @@ -650,7 +655,7 @@ class MultiScaleBlock(nn.Module): x = self.shortcut_pool_attn(x) x = x.reshape(B, C, -1).transpose(1, 2) if cls_tok is not None: - x = torch.cat((cls_tok, x), dim=2) + x = torch.cat((cls_tok, x), dim=1) return x def forward(self, x, feat_size: List[int]): @@ -996,3 +1001,8 @@ def mvitv2_large(pretrained=False, **kwargs): # @register_model # def mvitv2_huge_in21k(pretrained=False, **kwargs): # return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs) + + +@register_model +def mvitv2_small_cls(pretrained=False, **kwargs): + return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)