|
|
|
@ -135,6 +135,11 @@ model_cfgs = dict(
|
|
|
|
|
num_heads=2,
|
|
|
|
|
expand_attn=False,
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
mvitv2_small_cls=MultiScaleVitCfg(
|
|
|
|
|
depths=(1, 2, 11, 2),
|
|
|
|
|
use_cls_token=True,
|
|
|
|
|
),
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -641,7 +646,7 @@ class MultiScaleBlock(nn.Module):
|
|
|
|
|
if self.shortcut_pool_attn is None:
|
|
|
|
|
return x
|
|
|
|
|
if self.has_cls_token:
|
|
|
|
|
cls_tok, x = x[:, :, :1, :], x[:, :, 1:, :]
|
|
|
|
|
cls_tok, x = x[:, :1, :], x[:, 1:, :]
|
|
|
|
|
else:
|
|
|
|
|
cls_tok = None
|
|
|
|
|
B, L, C = x.shape
|
|
|
|
@ -650,7 +655,7 @@ class MultiScaleBlock(nn.Module):
|
|
|
|
|
x = self.shortcut_pool_attn(x)
|
|
|
|
|
x = x.reshape(B, C, -1).transpose(1, 2)
|
|
|
|
|
if cls_tok is not None:
|
|
|
|
|
x = torch.cat((cls_tok, x), dim=2)
|
|
|
|
|
x = torch.cat((cls_tok, x), dim=1)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def forward(self, x, feat_size: List[int]):
|
|
|
|
@ -996,3 +1001,8 @@ def mvitv2_large(pretrained=False, **kwargs):
|
|
|
|
|
# @register_model
|
|
|
|
|
# def mvitv2_huge_in21k(pretrained=False, **kwargs):
|
|
|
|
|
# return _create_mvitv2('mvitv2_huge_in21k', pretrained=pretrained, **kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
def mvitv2_small_cls(pretrained=False, **kwargs):
|
|
|
|
|
return _create_mvitv2('mvitv2_small_cls', pretrained=pretrained, **kwargs)
|
|
|
|
|