Fix grad checkpointing in focalnet

Branch: pull/1628/head
Author: Ross Wightman
Parent: 848d200767
Commit: cf324ea38f

@@ -266,6 +266,10 @@ class BasicLayer(nn.Module):
             )
             for i in range(depth)])
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x):
         x = self.downsample(x)
         for blk in self.blocks:
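
The hunk above cuts off before showing how the new flag is consumed inside forward. For context, here is a minimal runnable sketch of the usual timm pattern, where each block is routed through torch.utils.checkpoint when the flag is set; the Linear stand-in blocks, Identity downsample, and dimensions are illustrative assumptions, not the actual FocalNet code.

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint as checkpoint


    class BasicLayer(nn.Module):
        def __init__(self, dim, depth):
            super().__init__()
            # Stand-ins so the sketch runs standalone; the real layer
            # uses FocalNetBlock modules and may actually downsample.
            self.downsample = nn.Identity()
            self.blocks = nn.ModuleList([nn.Linear(dim, dim) for _ in range(depth)])
            self.grad_checkpointing = False

        @torch.jit.ignore
        def set_grad_checkpointing(self, enable=True):
            self.grad_checkpointing = enable

        def forward(self, x):
            x = self.downsample(x)
            for blk in self.blocks:
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    # Trade compute for memory: activations are recomputed
                    # during backward instead of being stored.
                    x = checkpoint.checkpoint(blk, x)
                else:
                    x = blk(x)
            return x


    # Usage: enable checkpointing, then train as usual.
    layer = BasicLayer(dim=64, depth=2)
    layer.set_grad_checkpointing(True)
    out = layer(torch.randn(8, 64, requires_grad=True))
    out.sum().backward()

The @torch.jit.ignore decorator keeps the setter out of TorchScript compilation, and the is_scripting() guard matters because checkpoint itself is not scriptable.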
