|
|
|
@ -523,9 +523,9 @@ class DaViT(nn.Module):
|
|
|
|
|
features.append(features[-1])
|
|
|
|
|
sizes.append(sizes[-1])
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
for block_index, block_param in enumerate(self.architecture):
|
|
|
|
|
|
|
|
|
|
branch_ids = sorted(set(block_param))
|
|
|
|
@ -562,7 +562,12 @@ class DaViT(nn.Module):
|
|
|
|
|
H, W = sizes[i]
|
|
|
|
|
out = x_out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()
|
|
|
|
|
outs.append(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# non-normalized pyramid features + corresponding sizes
|
|
|
|
|
return tuple(features), tuple(sizes)
|
|
|
|
|
|
|
|
|
|