Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent a813bbd1df
commit c82f9204e7

@ -514,7 +514,8 @@ class DaViT(nn.Module):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
for block in enumerate(stage):
for _, layer in enumerate(block):
for layer in enumerate(block):
print(layer)
if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
else:

Loading…
Cancel
Save