Fix grad checkpointing in focalnet

pull/1628/head
Ross Wightman 1 year ago
parent 848d200767
commit cf324ea38f

@ -266,6 +266,10 @@ class BasicLayer(nn.Module):
)
for i in range(depth)])
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x):
x = self.downsample(x)
for blk in self.blocks:

Loading…
Cancel
Save