|
|
@ -514,7 +514,8 @@ class DaViT(nn.Module):
|
|
|
|
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
|
|
|
|
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
|
|
|
|
|
|
|
|
|
|
|
|
for block in enumerate(stage):
|
|
|
|
for block in enumerate(stage):
|
|
|
|
for _, layer in enumerate(block):
|
|
|
|
for layer in enumerate(block):
|
|
|
|
|
|
|
|
print(layer)
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
if self.grad_checkpointing and not torch.jit.is_scripting():
|
|
|
|
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
|
|
|
|
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|