From cf324ea38f7a537992b3f61c22f73ec927c899d1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 17 Feb 2023 16:26:26 -0800 Subject: [PATCH] Fix grad checkpointing in focalnet --- timm/models/focalnet.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/models/focalnet.py b/timm/models/focalnet.py index 40cbe39d..8178cfc3 100644 --- a/timm/models/focalnet.py +++ b/timm/models/focalnet.py @@ -266,6 +266,10 @@ class BasicLayer(nn.Module): ) for i in range(depth)]) + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + def forward(self, x): x = self.downsample(x) for blk in self.blocks: