Propagate gradient checkpointing through the DaViT stages.

* DaViTStage gains set_grad_checkpointing(), which stores the flag that
  DaViTStage.forward() checks.
* DaViT.set_grad_checkpointing() now forwards the flag to every stage.
* DaViT.forward_features() runs self.stages through checkpoint_seq() when
  checkpointing is enabled and the model is not being scripted.

diff --git a/timm/models/davit.py b/timm/models/davit.py
index f21e799e..510b7206 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -407,7 +407,11 @@ class DaViTStage(nn.Module):
             stage_blocks.append(nn.Sequential(*dual_attention_block))
 
         self.blocks = nn.Sequential(*stage_blocks)
-        
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x : Tensor):
         x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -522,6 +526,8 @@ class DaViT(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
+        for stage in self.stages:
+            stage.set_grad_checkpointing(enable=enable)
 
     @torch.jit.ignore
     def get_classifier(self):
@@ -535,7 +541,10 @@ class DaViT(nn.Module):
 
     def forward_features(self, x):
         x = self.stem(x)
-        x = self.stages(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
 