From 5c7c8f2e36c737c454b99a11d5071b639794c0c1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 7 Dec 2022 00:45:04 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 464b976f..aa006a7d 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -510,14 +510,15 @@ class DaViT(nn.Module): - for patch_layer, blocks in zip(self.patch_embeds, self.main_blocks): + for patch_layer, stage in zip(self.patch_embeds, self.main_blocks): features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) - for layer in enumerate(blocks): - if self.grad_checkpointing and not torch.jit.is_scripting(): - features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) - else: - features[-1], sizes[-1] = layer(features[-1], sizes[-1]) + for block in enumerate(stage): + for layer in enumerate(block): + if self.grad_checkpointing and not torch.jit.is_scripting(): + features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) + else: + features[-1], sizes[-1] = layer(features[-1], sizes[-1]) features.append(features[-1])