diff --git a/timm/models/davit.py b/timm/models/davit.py
index dd089469..f8e4a262 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -518,18 +518,22 @@ class DaViT(nn.Module):
 
                 branches.append(branch_id)
         '''
 
+            block_index : int
+
             if block_index not in branches:
-                x, size = self.patch_embeds[int(block_index)](features[-1], sizes[-1])
+                x, size = self.patch_embeds[block_index](features[-1], sizes[-1])
                 features.append(x)
                 sizes.append(size)
                 branches.append(branch_id)
 
             for layer_index, branch_id in enumerate(block_param):
+                layer_index : int
+                branch_id : int
                 if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[int(branch_id)], _ = checkpoint.checkpoint(self.main_blocks[int(block_index)][int(layer_index)], features[int(branch_id)], sizes[int(branch_id)])
+                    features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
                 else:
-                    features[int(branch_id)], _ = self.main_blocks[int(block_index)][int(layer_index)](features[int(branch_id)], sizes[int(branch_id)])
+                    features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
         '''
         # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
         outs = []
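
Context note (not part of the patch): the hunk replaces per-use int(...) casts with one-time bare annotations on the loop variables (block_index : int, etc.), a common idiom for informing TorchScript's type inference, and it preserves the torch.jit.is_scripting() guard because torch.utils.checkpoint cannot run under TorchScript. A minimal sketch of that guard pattern, using a hypothetical Block module that is not from the patch:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class Block(nn.Module):
    # Hypothetical stand-in for one DaViT block: a single linear layer.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.grad_checkpointing = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # checkpoint() is not TorchScript-compatible, so fall back to a
        # direct call whenever the module is being scripted.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            return checkpoint.checkpoint(self.proj, x)
        return self.proj(x)


# Both paths produce the same output; only training-time memory use differs.
m = Block()
out = m(torch.randn(2, 8, requires_grad=True))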