|
|
@ -381,18 +381,30 @@ class Twins(nn.Module):
|
|
|
|
if self.task == 'seg':
|
|
|
|
if self.task == 'seg':
|
|
|
|
outputs = list()
|
|
|
|
outputs = list()
|
|
|
|
B = x.shape[0]
|
|
|
|
B = x.shape[0]
|
|
|
|
for i, (embed, drop, blocks, pos_blk) in enumerate(
|
|
|
|
if self.extra_norm:
|
|
|
|
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
|
|
|
|
for i, (embed, drop, blocks, pos_blk, norm_list) in enumerate(
|
|
|
|
x, size = embed(x)
|
|
|
|
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block, self.norm_list)):
|
|
|
|
x = drop(x)
|
|
|
|
x, size = embed(x)
|
|
|
|
for j, blk in enumerate(blocks):
|
|
|
|
x = drop(x)
|
|
|
|
x = blk(x, size)
|
|
|
|
for j, blk in enumerate(blocks):
|
|
|
|
if j == 0:
|
|
|
|
x = blk(x, size)
|
|
|
|
x = pos_blk(x, size) # PEG here
|
|
|
|
if j == 0:
|
|
|
|
if self.extra_norm:
|
|
|
|
x = pos_blk(x, size) # PEG here
|
|
|
|
x = self.norm_list[i](x)
|
|
|
|
x = norm_list(x)
|
|
|
|
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
|
|
|
|
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
|
|
|
|
outputs.append(x)
|
|
|
|
outputs.append(x)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
for i, (embed, drop, blocks, pos_blk) in enumerate(
|
|
|
|
|
|
|
|
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
|
|
|
|
|
|
|
|
x, size = embed(x)
|
|
|
|
|
|
|
|
x = drop(x)
|
|
|
|
|
|
|
|
for j, blk in enumerate(blocks):
|
|
|
|
|
|
|
|
x = blk(x, size)
|
|
|
|
|
|
|
|
if j == 0:
|
|
|
|
|
|
|
|
x = pos_blk(x, size) # PEG here
|
|
|
|
|
|
|
|
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
|
|
|
outputs.append(x)
|
|
|
|
|
|
|
|
|
|
|
|
return outputs
|
|
|
|
return outputs
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|