Fix class token support in MViT-V2, add small_class variant to ensure it's tested. Fix #1443

pull/804/merge
Ross Wightman 2 years ago
parent b94b7cea65
commit f66e5f0e35

@ -135,6 +135,11 @@ model_cfgs = dict(
num_heads=2,
expand_attn=False,
),
mvitv2_small_cls=MultiScaleVitCfg(
depths=(1, 2, 11, 2),
use_cls_token=True,
),
)
@ -641,7 +646,7 @@ class MultiScaleBlock(nn.Module):
if self.shortcut_pool_attn is None:
return x
if self.has_cls_token:
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
cls_tok, x = x[:, :1, :], x[:, 1:, :]
else:
cls_tok = None
B, L, C = x.shape
@ -650,7 +655,7 @@ class MultiScaleBlock(nn.Module):
x = self.shortcut_pool_attn(x)
x = x.reshape(B, C, -1).transpose(1, 2)
if cls_tok is not None:
x = torch.cat((cls_tok, x), dim=2)
x = torch.cat((cls_tok, x), dim=1)
return x
def forward(self, x, feat_size: List[int]):
@ -996,3 +1001,8 @@ def mvitv2_large(pretrained=False, **kwargs):
# @register_model
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
@register_model
def mvitv2_small_cls(pretrained=False, **kwargs):
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save