is_scripting() guard on checkpoint_seq

pull/1655/head
Ross Wightman 2 years ago
parent 95ec255f7f
commit 72fba669a8

@ -336,7 +336,7 @@ class EfficientFormerStage(nn.Module):
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
if self.grad_checkpointing: if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
x = self.blocks(x) x = self.blocks(x)

@ -499,7 +499,7 @@ class EfficientFormerV2Stage(nn.Module):
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
if self.grad_checkpointing: if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
else: else:
x = self.blocks(x) x = self.blocks(x)

Loading…
Cancel
Save