Update davit.py

Branch: pull/1630/head
Author: Fredo Guan, 3 years ago
Parent: 375be19be6
Commit: 62b9d696b0

@@ -431,7 +431,7 @@ class DaViT(nn.Module):
            layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))
            block = nn.ModuleList([
                MySequential(*[
                    nn.ModuleList([
                        ChannelBlock(
                            dim=self.embed_dims[item],
                            num_heads=self.num_heads[item],
@@ -513,12 +513,12 @@ class DaViT(nn.Module):
         for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
             features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for layer in enumerate(stage):
-            #for layer in enumerate(block):
-                if self.grad_checkpointing and not torch.jit.is_scripting():
-                    features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
-                else:
-                    features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
+            for block in enumerate(stage):
+                for layer in enumerate(block):
+                    if self.grad_checkpointing and not torch.jit.is_scripting():
+                        features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
+                    else:
+                        features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
             features.append(features[-1])
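
Below is a minimal, self-contained sketch of what the second hunk appears to be moving toward: a nested stage/block forward loop with optional gradient checkpointing. TinyDaViTStages and DummyBlock are hypothetical stand-ins for DaViT and its MySequential-wrapped attention blocks; the only behavior carried over from the diff is that each block consumes and returns a (features, size) tuple and that checkpointing is skipped under torch.jit.is_scripting(). The sketch iterates the ModuleLists directly rather than over enumerate(), since enumerate() yields (index, module) pairs instead of callable modules.

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint


class DummyBlock(nn.Module):
    # Hypothetical stand-in for a DaViT attention block:
    # consumes and returns a (features, size) tuple.
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, inputs):
        x, size = inputs
        return self.proj(x), size


class TinyDaViTStages(nn.Module):
    # Hypothetical container mirroring self.main_blocks: a ModuleList of
    # stages, each of which is a ModuleList of blocks.
    def __init__(self, dim=32, num_stages=2, blocks_per_stage=2):
        super().__init__()
        self.grad_checkpointing = False
        self.main_blocks = nn.ModuleList([
            nn.ModuleList([DummyBlock(dim) for _ in range(blocks_per_stage)])
            for _ in range(num_stages)
        ])

    def forward(self, x, size):
        for stage in self.main_blocks:
            # Iterate the modules themselves; wrapping the loop in enumerate()
            # would bind (index, module) tuples to the loop variable.
            for block in stage:
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    # Recompute activations during backward to save memory.
                    x, size = checkpoint.checkpoint(block, (x, size))
                else:
                    x, size = block((x, size))
        return x, size


model = TinyDaViTStages()
out, out_size = model(torch.randn(2, 49, 32), (7, 7))
print(out.shape)  # torch.Size([2, 49, 32])

The torch.jit.is_scripting() guard from the diff is kept here as well, so the checkpointing branch is bypassed when the module is being scripted.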
