fix some bugs

pull/659/head
lixinjie 4 years ago
parent 56bd5822f1
commit aebc5b58c9

@ -179,7 +179,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'vit_large_*', 'vit_huge_*',
'vit_large_*', 'vit_huge_*', 'twins_*',
]

@ -291,12 +291,12 @@ class Twins(nn.Module):
patch_size = 2
self.task = task
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
self.norm_list = nn.ModuleList()
if self.task == 'seg':
self.F4=F4
self.extra_norm = extra_norm
self.strides = strides
if self.extra_norm:
self.norm_list = nn.ModuleList()
for dim in embed_dims:
self.norm_list.append(norm_layer(dim))
@ -381,30 +381,18 @@ class Twins(nn.Module):
if self.task == 'seg':
outputs = list()
B = x.shape[0]
if self.extra_norm:
for i, (embed, drop, blocks, pos_blk, norm_list) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block, self.norm_list)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
x = norm_list(x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
else:
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if self.extra_norm:
x = self.norm_list[i](x)
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
outputs.append(x)
return outputs
def forward(self, x):

Loading…
Cancel
Save