Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 5c7c8f2e36
commit 8e421bccfb

@ -516,9 +516,9 @@ class DaViT(nn.Module):
for block in enumerate(stage):
for layer in enumerate(block):
if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
else:
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
features.append(features[-1])

Loading…
Cancel
Save