Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 11f27df29f
commit f7a8fb9f97

@ -494,9 +494,7 @@ class DaViT(nn.Module):
for patch_layer, stage in zip(self.patch_embeds, self.main_blocks): for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
for _, block in enumerate(stage): for _, block in enumerate(stage):
print(block)
for _, layer in enumerate(block): for _, layer in enumerate(block):
print(layer)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
else: else:

Loading…
Cancel
Save