diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 40cbe39d..8178cfc3 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -266,6 +266,10 @@ class BasicLayer(nn.Module): ) for i in range(depth)]) + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + def forward(self, x): x = self.downsample(x) for blk in self.blocks: