From 62b9d696b0e7d1c11084c2c385383f8912ee127c Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Wed, 7 Dec 2022 00:51:12 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 2aee5b1a..270f687c 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -431,7 +431,7 @@ class DaViT(nn.Module):
             layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
 
             block = nn.ModuleList([
-                MySequential(*[
+                nn.ModuleList([
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
@@ -513,12 +513,12 @@ class DaViT(nn.Module):
 
         for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
             features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for layer in enumerate(stage):
-            #for layer in enumerate(block):
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
-                else:
-                    features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
+            for block in stage:
+                for layer in block:
+                    if self.grad_checkpointing and not torch.jit.is_scripting():
+                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
+                    else:
+                        features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
             features.append(features[-1])
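
For context, a minimal sketch (not part of the patch) of the traversal pattern the nested nn.ModuleList layout implies; DummyBlock and the stage/x/size names are hypothetical stand-ins for DaViT's ChannelBlock/SpatialBlock and its (features, sizes) bookkeeping:

    import torch
    import torch.nn as nn

    class DummyBlock(nn.Module):
        # Toy block mirroring the (features, sizes) tuple interface used in the patch.
        def forward(self, inputs):
            x, size = inputs
            return x, size

    # A stage is an outer ModuleList of blocks, each itself a ModuleList of layers.
    stage = nn.ModuleList([nn.ModuleList([DummyBlock(), DummyBlock()]) for _ in range(2)])

    x, size = torch.randn(1, 4, 96), (2, 2)
    for block in stage:      # iterate the blocks directly, so each item is a module
        for layer in block:  # iterate the layers inside each block
            x, size = layer((x, size))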