Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent f1fff20524
commit 90a71c35b3

@@ -519,7 +519,7 @@ class DaViT(nn.Module):
         '''
                 if block_index not in branches:
-                    x, size = self.patch_embeds[block_index](features[-1], sizes[-1])
+                    x, size = self.patch_embeds[int(block_index)](features[-1], sizes[-1])
                     features.append(x)
                     sizes.append(size)
                     branches.append(branch_id)
@@ -527,9 +527,9 @@ class DaViT(nn.Module):
             for layer_index, branch_id in enumerate(block_param):
                 if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
+                    features[int(branch_id)], _ = checkpoint.checkpoint(self.main_blocks[int(block_index)][int(layer_index)], features[int(branch_id)], sizes[int(branch_id)])
                 else:
-                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
+                    features[int(branch_id)], _ = self.main_blocks[int(block_index)][int(layer_index)](features[int(branch_id)], sizes[int(branch_id)])
         '''
         # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
         outs = []
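
The change in both hunks is the same: every index used to pick a module out of the nested ModuleLists (and to pick a branch out of the feature/size lists) is wrapped in int(), presumably so that indices coming from the architecture config behave as plain Python ints for ModuleList indexing and TorchScript. Below is a minimal, hedged sketch of that pattern, not the timm code; TwoBranchStage, blocks, and block_param are hypothetical names introduced only for illustration.

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as checkpoint

    class TwoBranchStage(nn.Module):
        """Hypothetical stand-in for the DaViT main-block loop (illustration only)."""
        def __init__(self, blocks: nn.ModuleList, grad_checkpointing: bool = False):
            super().__init__()
            self.blocks = blocks  # assumed nested ModuleList: blocks[block_index][layer_index]
            self.grad_checkpointing = grad_checkpointing

        def forward(self, features, block_index, block_param):
            # features: list of tensors, one per branch; block_param: iterable of branch ids
            for layer_index, branch_id in enumerate(block_param):
                # cast every index with int(), mirroring the diff above
                blk = self.blocks[int(block_index)][int(layer_index)]
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    # recompute activations in backward to save memory
                    features[int(branch_id)] = checkpoint.checkpoint(blk, features[int(branch_id)])
                else:
                    features[int(branch_id)] = blk(features[int(branch_id)])
            return features

    # Usage sketch (assumed shapes/layers, for illustration only):
    blocks = nn.ModuleList([nn.ModuleList([nn.Linear(8, 8) for _ in range(2)])])
    stage = TwoBranchStage(blocks, grad_checkpointing=False)
    out = stage([torch.randn(4, 8)], block_index=0, block_param=(0, 0))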
