Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent c1a5882c5a
commit 5c7c8f2e36

@ -510,14 +510,15 @@ class DaViT(nn.Module):
for patch_layer, blocks in zip(self.patch_embeds, self.main_blocks): for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1]) features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
for layer in enumerate(blocks): for block in enumerate(stage):
if self.grad_checkpointing and not torch.jit.is_scripting(): for layer in enumerate(block):
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) if self.grad_checkpointing and not torch.jit.is_scripting():
else: features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
features[-1], sizes[-1] = layer(features[-1], sizes[-1]) else:
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
features.append(features[-1]) features.append(features[-1])

Loading…
Cancel
Save