fix some bugs to support segmentation task

pull/659/head
lixinjie 4 years ago
parent f8a16f5267
commit 041a85fe7b

@ -291,10 +291,10 @@ class Twins(nn.Module):
patch_size = 2
self.task = task
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
if self.task == 'seg':
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
if self.extra_norm:
self.norm_list = nn.ModuleList()
for dim in embed_dims:
@ -363,7 +363,6 @@ class Twins(nn.Module):
m.bias.data.zero_()
def forward_features(self, x):
if self.task == 'cls':
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
@ -404,7 +403,6 @@ class Twins(nn.Module):
if self.F4:
x = x[3:4]
return x
def _create_twins(variant, pretrained=False, **kwargs):
@ -466,4 +464,4 @@ def twins_svt_large(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)

Loading…
Cancel
Save