pull/659/head
lixinjie 4 years ago
parent 8d9ebe3788
commit 56bd5822f1

@ -381,18 +381,30 @@ class Twins(nn.Module):
if self.task == 'seg':
outputs = list()
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if self.extra_norm:
x = self.norm_list[i](x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
if self.extra_norm:
for i, (embed, drop, blocks, pos_blk, norm_list) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block, self.norm_list)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
x = norm_list(x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
else:
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
return outputs
def forward(self, x):

Loading…
Cancel
Save