diff --git a/timm/models/twins.py b/timm/models/twins.py index 3b15ed93..ff973ad8 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -291,10 +291,10 @@ class Twins(nn.Module): patch_size = 2 self.task = task + self.F4=F4 + self.extra_norm = extra_norm + self.strides = strides if self.task == 'seg': - self.F4=F4 - self.extra_norm = extra_norm - self.strides = strides if self.extra_norm: self.norm_list = nn.ModuleList() for dim in embed_dims: @@ -363,7 +363,6 @@ class Twins(nn.Module): m.bias.data.zero_() def forward_features(self, x): - if self.task == 'cls': B = x.shape[0] for i, (embed, drop, blocks, pos_blk) in enumerate( @@ -404,7 +403,6 @@ class Twins(nn.Module): if self.F4: x = x[3:4] return x - def _create_twins(variant, pretrained=False, **kwargs): @@ -466,4 +464,4 @@ def twins_svt_large(pretrained=False, **kwargs): model_kwargs = dict( patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4], depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs) - return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs) \ No newline at end of file + return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)