From e222f474f421f208e7d6d6cc7c25663e04bef16c Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 22:30:58 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index dd089469..f8e4a262 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -518,18 +518,22 @@ class DaViT(nn.Module): branches.append(branch_id) ''' + block_index : int + if block_index not in branches: - x, size = self.patch_embeds[int(block_index)](features[-1], sizes[-1]) + x, size = self.patch_embeds[block_index](features[-1], sizes[-1]) features.append(x) sizes.append(size) branches.append(branch_id) for layer_index, branch_id in enumerate(block_param): + layer_index : int + branch_id : int if self.grad_checkpointing and not torch.jit.is_scripting(): - features[int(branch_id)], _ = checkpoint.checkpoint(self.main_blocks[int(block_index)][int(layer_index)], features[int(branch_id)], sizes[int(branch_id)]) + features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id]) else: - features[int(branch_id)], _ = self.main_blocks[int(block_index)][int(layer_index)](features[int(branch_id)], sizes[int(branch_id)]) + features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id]) ''' # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model outs = []