@@ -336,7 +336,7 @@ class EfficientFormerStage(nn.Module):

     def forward(self, x):
         x = self.downsample(x)
-        if self.grad_checkpointing:
+        if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
             x = self.blocks(x)
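# --- Illustration only, not part of the patch above ---
# A minimal standalone sketch of the guard being introduced: gradient
# checkpointing is only taken in eager mode, so scripting the module does not
# have to compile the (non-scriptable) checkpoint helper path. TinyStage, its
# dimensions, and the use of torch.utils.checkpoint.checkpoint_sequential in
# place of timm's checkpoint_seq are assumptions for demonstration.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential


class TinyStage(nn.Module):
    def __init__(self, dim: int = 16, depth: int = 4):
        super().__init__()
        self.grad_checkpointing = False
        self.blocks = nn.Sequential(*[nn.Linear(dim, dim) for _ in range(depth)])

    def forward(self, x):
        # Same pattern as the diff: checkpoint only when enabled AND not
        # under TorchScript, otherwise fall through to the plain Sequential.
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint_sequential(self.blocks, 2, x)
        else:
            x = self.blocks(x)
        return x


if __name__ == "__main__":
    stage = TinyStage()
    stage.grad_checkpointing = True
    # Eager forward/backward with checkpointing enabled; the is_scripting()
    # guard is what the patch adds so scripting the stage remains possible.
    out = stage(torch.randn(2, 16, requires_grad=True))
    out.sum().backward()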
@@ -499,7 +499,7 @@ class EfficientFormerV2Stage(nn.Module):