diff --git a/timm/models/davit.py b/timm/models/davit.py
index 02098ac3..b0a6dcff 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -38,8 +38,9 @@ __all__ = ['DaViT']
 class MySequential(nn.Sequential):
     def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
         for module in self:
-            #if type(inputs) == tuple:
             inputs = module(*inputs)
+            #if type(inputs) == tuple:
+            #    inputs = module(*inputs)
             #else:
             #    inputs = module(inputs)
         return inputs
@@ -507,14 +508,23 @@ class DaViT(nn.Module):
 
         branches = [0]
         for block_index, block_param in enumerate(self.architecture):
-            #branch_ids = sorted(set(block_param))
-            branch_ids = block_index
+            '''
+            branch_ids = sorted(set(block_param))
             for branch_id in branch_ids:
                 if branch_id not in branches:
                     x, size = self.patch_embeds[branch_id](features[-1], sizes[-1])
                     features.append(x)
                     sizes.append(size)
                     branches.append(branch_id)
+            '''
+
+            if block_index not in branches:
+                x, size = self.patch_embeds[block_index](features[-1], sizes[-1])
+                features.append(x)
+                sizes.append(size)
+                branches.append(block_index)
+
+
             for layer_index, branch_id in enumerate(block_param):
                 if self.grad_checkpointing and not torch.jit.is_scripting():
                     features[branch_id], _ = checkpoint.checkpoint(self.main_blocks[block_index][layer_index], features[branch_id], sizes[branch_id])