Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent e9a16ee19d
commit 625bb86a64

@ -512,10 +512,9 @@ class DaViT(nn.Module):
for patch_layer, stage in zip(self.patch_embeds, self.main_blocks): for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
print(stage) for _, block in enumerate(stage):
for block in enumerate(stage):
print(block) print(block)
for layer in enumerate(block): for _, layer in enumerate(block):
print(layer) print(layer)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1])) features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))

Loading…
Cancel
Save