Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 234748cf30
commit 96e0b06ee1

@ -523,9 +523,9 @@ class DaViT(nn.Module):
features.append(features[-1]) features.append(features[-1])
sizes.append(sizes[-1]) sizes.append(sizes[-1])
'''
'''
for block_index, block_param in enumerate(self.architecture): for block_index, block_param in enumerate(self.architecture):
branch_ids = sorted(set(block_param)) branch_ids = sorted(set(block_param))
@ -562,7 +562,12 @@ class DaViT(nn.Module):
H, W = sizes[i] H, W = sizes[i]
out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous() out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
outs.append(out) outs.append(out)
''' '''
# non-normalized pyramid features + corresponding sizes # non-normalized pyramid features + corresponding sizes
return tuple(features), tuple(sizes) return tuple(features), tuple(sizes)

Loading…
Cancel
Save