From 56bd5822f12666402b51877c2296a532d1d8fdce Mon Sep 17 00:00:00 2001 From: lixinjie Date: Mon, 31 May 2021 12:51:25 +0800 Subject: [PATCH] fix --- timm/models/twins.py | 36 ++++++++++++++++++++++++------------ 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/timm/models/twins.py b/timm/models/twins.py index c69bcf9b..79ef9e08 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -381,18 +381,30 @@ class Twins(nn.Module): if self.task == 'seg': outputs = list() B = x.shape[0] - for i, (embed, drop, blocks, pos_blk) in enumerate( - zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): - x, size = embed(x) - x = drop(x) - for j, blk in enumerate(blocks): - x = blk(x, size) - if j == 0: - x = pos_blk(x, size) # PEG here - if self.extra_norm: - x = self.norm_list[i](x) - x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() - outputs.append(x) + if self.extra_norm: + for i, (embed, drop, blocks, pos_blk, norm_list) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block, self.norm_list)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + x = norm_list(x) + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + outputs.append(x) + else: + for i, (embed, drop, blocks, pos_blk) in enumerate( + zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): + x, size = embed(x) + x = drop(x) + for j, blk in enumerate(blocks): + x = blk(x, size) + if j == 0: + x = pos_blk(x, size) # PEG here + x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() + outputs.append(x) + return outputs def forward(self, x):