fix some bugs

pull/659/head
lixinjie 4 years ago
parent 041a85fe7b
commit bd996241f6

@ -294,9 +294,9 @@ class Twins(nn.Module):
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
self.norm_list = nn.ModuleList()
if self.task == 'seg':
if self.extra_norm:
self.norm_list = nn.ModuleList()
for dim in embed_dims:
self.norm_list.append(norm_layer(dim))
@ -311,6 +311,8 @@ class Twins(nn.Module):
self.patch_embeds.append(
PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i]))
s = s * strides[i-1]
if self.task == 'cls':
del self.norm_list
self.blocks = nn.ModuleList()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule

Loading…
Cancel
Save