From 72fba669a8b661e4e1d02118db58a0d29c66fbc7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 4 Feb 2023 14:21:49 -0800 Subject: [PATCH] is_scripting() guard on checkpoint_seq --- timm/models/efficientformer.py | 2 +- timm/models/efficientformer_v2.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/efficientformer.py b/timm/models/efficientformer.py index 21957d58..c6920020 100644 --- a/timm/models/efficientformer.py +++ b/timm/models/efficientformer.py @@ -336,7 +336,7 @@ class EfficientFormerStage(nn.Module): def forward(self, x): x = self.downsample(x) - if self.grad_checkpointing: + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x) diff --git a/timm/models/efficientformer_v2.py b/timm/models/efficientformer_v2.py index e2adccdb..737e314a 100644 --- a/timm/models/efficientformer_v2.py +++ b/timm/models/efficientformer_v2.py @@ -499,7 +499,7 @@ class EfficientFormerV2Stage(nn.Module): def forward(self, x): x = self.downsample(x) - if self.grad_checkpointing: + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) else: x = self.blocks(x)