From 90a71c35b369ecef390d2ad15e2f494ce33e9dae Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Tue, 6 Dec 2022 22:27:13 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index b0a6dcff..dd089469 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -519,7 +519,7 @@ class DaViT(nn.Module):
         '''
         if block_index not in branches:
-            x, size = self.patch_embeds[block_index](features[-1], sizes[-1])
+            x, size = self.patch_embeds[int(block_index)](features[-1], sizes[-1])
             features.append(x)
             sizes.append(size)
             branches.append(branch_id)
@@ -527,9 +527,9 @@
         for layer_index, branch_id in enumerate(block_param):
             if self.grad_checkpointing and not torch.jit.is_scripting():
-                features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])
+                features[int(branch_id)], _ = checkpoint.checkpoint(self.main_blocks[int(block_index)][int(layer_index)], features[int(branch_id)], sizes[int(branch_id)])
             else:
-                features[branch_id], _ = self.main_blocks[block_index][layer_index](features[branch_id], sizes[branch_id])
+                features[int(branch_id)], _ = self.main_blocks[int(block_index)][int(layer_index)](features[int(branch_id)], sizes[int(branch_id)])
         '''
         # pyramid feature norm logic, no weights for these extra norm layers from pretrained classification model
         outs = []
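
Note (not part of the patch): the change casts every index to a plain Python int
before it is used to subscript `features`, `sizes`, `self.patch_embeds` and
`self.main_blocks`. A minimal, hypothetical sketch of where such casts matter is
below; it assumes an index that is not already a plain int (e.g. a 0-dim tensor)
and a scripted function, which is not necessarily how `block_index`/`branch_id`
are produced in davit.py. The `pick` helper and its arguments are invented for
illustration only.

    import torch
    from typing import List

    @torch.jit.script
    def pick(feats: List[torch.Tensor], idx: torch.Tensor) -> torch.Tensor:
        # int() converts a 0-dim tensor into a plain Python int, which is the
        # index type TorchScript expects when subscripting a List.
        return feats[int(idx)]

    print(pick([torch.zeros(1), torch.ones(1)], torch.tensor(1)))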