diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 21957d58..c6920020 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -336,7 +336,7 @@ class EfficientFormerStage(nn.Module): def forward(self, x): x = self.downsample(x) - if self.grad_checkpointing: + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index e2adccdb..737e314a 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -499,7 +499,7 @@ class EfficientFormerV2Stage(nn.Module): def forward(self, x): x = self.downsample(x) - if self.grad_checkpointing: + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x)