|
|
|
@ -291,10 +291,10 @@ class Twins(nn.Module):
|
|
|
|
|
patch_size = 2
|
|
|
|
|
|
|
|
|
|
self.task = task
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
if self.task == 'seg':
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
if self.extra_norm:
|
|
|
|
|
self.norm_list = nn.ModuleList()
|
|
|
|
|
for dim in embed_dims:
|
|
|
|
@ -363,7 +363,6 @@ class Twins(nn.Module):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
|
|
|
|
|
|
if self.task == 'cls':
|
|
|
|
|
B = x.shape[0]
|
|
|
|
|
for i, (embed, drop, blocks, pos_blk) in enumerate(
|
|
|
|
@ -404,7 +403,6 @@ class Twins(nn.Module):
|
|
|
|
|
if self.F4:
|
|
|
|
|
x = x[3:4]
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_twins(variant, pretrained=False, **kwargs):
|
|
|
|
@ -466,4 +464,4 @@ def twins_svt_large(pretrained=False, **kwargs):
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
|
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
|
|
|
|
|
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
|
|
|
|
|
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
|
|
|
|
|