From bd996241f68f2766da933497f72e4d3ce2a5017a Mon Sep 17 00:00:00 2001 From: lixinjie Date: Sun, 30 May 2021 21:25:28 +0800 Subject: [PATCH] fix some bugs --- timm/models/twins.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/models/twins.py b/timm/models/twins.py index ff973ad8..eded603a 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -294,9 +294,9 @@ class Twins(nn.Module): self.F4=F4 self.extra_norm = extra_norm self.strides = strides + self.norm_list = nn.ModuleList() if self.task == 'seg': if self.extra_norm: - self.norm_list = nn.ModuleList() for dim in embed_dims: self.norm_list.append(norm_layer(dim)) @@ -311,6 +311,8 @@ class Twins(nn.Module): self.patch_embeds.append( PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i])) s = s * strides[i-1] + if self.task == 'cls': + del self.norm_list self.blocks = nn.ModuleList() dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule