Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 90a71c35b3
commit e222f474f4

@ -518,18 +518,22 @@ class DaViT(nn.Module):
branches.append(branch_id) branches.append(branch_id)
''' '''
block_index : int
if block_index not in branches: if block_index not in branches:
x, size = self.patch_embeds[int(block_index)](features[-1], sizes[-1]) x, size = self.patch_embeds[block_index](features[-1], sizes[-1])
features.append(x) features.append(x)
sizes.append(size) sizes.append(size)
branches.append(branch_id) branches.append(branch_id)
for layer_index, branch_id in enumerate(block_param): for layer_index, branch_id in enumerate(block_param):
layer_index : int
branch_id : int
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
features[int(branch_id)], _ = checkpoint.checkpoint(self.main_blocks[int(block_index)][int(layer_index)], features[int(branch_id)], sizes[int(branch_id)]) features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
else: else:
features[int(branch_id)], _ = self.main_blocks[int(block_index)][int(layer_index)](features[int(branch_id)], sizes[int(branch_id)]) features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
''' '''
# pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
outs = [] outs = []

Loading…
Cancel
Save