pull/659/head
lixinjie 4 years ago
parent 8d9ebe3788
commit 56bd5822f1

@ -381,18 +381,30 @@ class Twins(nn.Module):
if self.task == 'seg': if self.task == 'seg':
outputs = list() outputs = list()
B = x.shape[0] B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate( if self.extra_norm:
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)): for i, (embed, drop, blocks, pos_blk, norm_list) in enumerate(
x, size = embed(x) zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block, self.norm_list)):
x = drop(x) x, size = embed(x)
for j, blk in enumerate(blocks): x = drop(x)
x = blk(x, size) for j, blk in enumerate(blocks):
if j == 0: x = blk(x, size)
x = pos_blk(x, size) # PEG here if j == 0:
if self.extra_norm: x = pos_blk(x, size) # PEG here
x = self.norm_list[i](x) x = norm_list(x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous() x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x) outputs.append(x)
else:
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
return outputs return outputs
def forward(self, x): def forward(self, x):

Loading…
Cancel
Save