|
|
|
@ -294,9 +294,9 @@ class Twins(nn.Module):
|
|
|
|
|
self.F4=F4
|
|
|
|
|
self.extra_norm = extra_norm
|
|
|
|
|
self.strides = strides
|
|
|
|
|
self.norm_list = nn.ModuleList()
|
|
|
|
|
if self.task == 'seg':
|
|
|
|
|
if self.extra_norm:
|
|
|
|
|
self.norm_list = nn.ModuleList()
|
|
|
|
|
for dim in embed_dims:
|
|
|
|
|
self.norm_list.append(norm_layer(dim))
|
|
|
|
|
|
|
|
|
@ -311,6 +311,8 @@ class Twins(nn.Module):
|
|
|
|
|
self.patch_embeds.append(
|
|
|
|
|
PatchEmbed(img_size // patch_size // s, strides[i-1], embed_dims[i - 1], embed_dims[i]))
|
|
|
|
|
s = s * strides[i-1]
|
|
|
|
|
if self.task == 'cls':
|
|
|
|
|
del self.norm_list
|
|
|
|
|
|
|
|
|
|
self.blocks = nn.ModuleList()
|
|
|
|
|
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
|
|
|
|