diff --git a/timm/models/twins.py b/timm/models/twins.py index ff973ad8..eded603a 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -294,9 +294,9 @@ class Twins(nn.Module): self.F4=F4 self.extra_norm = extra_norm self.strides = strides + self.norm_list = nn.ModuleList() if self.task == 'seg': if self.extra_norm: - self.norm_list = nn.ModuleList() for dim in embed_dims: self.norm_list.append(norm_layer(dim)) @@ -311,6 +311,8 @@ class Twins(nn.Module): self.patch_embeds.append( PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i])) s = s * strides[i-1] + if self.task == 'cls': + del self.norm_list self.blocks = nn.ModuleList() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule