@@ -431,7 +431,7 @@ class DaViT(nn.Module):
             layer_offset_id = len(list(itertools.chain(*self.architecture[:block_id])))

             block = nn.ModuleList([
-                MySequential(*[
+                nn.ModuleList([
                     ChannelBlock(
                         dim=self.embed_dims[item],
                         num_heads=self.num_heads[item],
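For context: `MySequential` in the DaViT reference implementation is a thin `nn.Sequential` subclass whose `forward` unpacks tuple inputs, so a chain of blocks can pass `(x, size)` pairs along without boilerplate. Swapping it for a plain `nn.ModuleList` here is what forces the explicit nested iteration in the forward pass in the next hunk. A minimal sketch of that helper, assuming the tuple-unpacking behavior of the reference code:

```python
import torch.nn as nn

class MySequential(nn.Sequential):
    """Sequential container that unpacks tuple inputs, letting chained
    blocks pass (x, size) pairs instead of a single tensor."""

    def forward(self, *inputs):
        for module in self._modules.values():
            # If the running value is a tuple, splat it into the module.
            inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
        return inputs
```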
@@ -513,8 +513,8 @@ class DaViT(nn.Module):
         for patch_layer, stage in zip(self.patch_embeds, self.main_blocks):
             features[-1], sizes[-1] = patch_layer(features[-1], sizes[-1])
-            for layer in enumerate(stage):
-                #for layer in enumerate(block):
+            for block in stage:
+                for layer in block:
                     if self.grad_checkpointing and not torch.jit.is_scripting():
                         features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
                     else:
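Taken together, the two hunks change each stage from a list of `MySequential` chains into a nested `nn.ModuleList`, with the forward pass iterating blocks explicitly so each one can be routed through `torch.utils.checkpoint.checkpoint`, trading recomputation in the backward pass for lower activation memory. A self-contained sketch of that iteration pattern, using a hypothetical `ToyBlock` in place of the actual `ChannelBlock`/`SpatialBlock` modules:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

class ToyBlock(nn.Module):
    """Stand-in for a DaViT block: consumes and returns an (x, size) tuple."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, inputs):
        x, size = inputs
        return self.proj(x), size

class ToyStage(nn.Module):
    def __init__(self, dim, depth=2, pairs=2, grad_checkpointing=False):
        super().__init__()
        self.grad_checkpointing = grad_checkpointing
        # Nested ModuleList, mirroring the structure built in the first hunk.
        self.blocks = nn.ModuleList([
            nn.ModuleList([ToyBlock(dim) for _ in range(pairs)])
            for _ in range(depth)
        ])

    def forward(self, x, size):
        for block in self.blocks:
            for layer in block:
                if self.grad_checkpointing and not torch.jit.is_scripting():
                    # Recompute this layer's activations during backward.
                    x, size = checkpoint.checkpoint(layer, (x, size))
                else:
                    x, size = layer((x, size))
        return x, size

stage = ToyStage(dim=8, grad_checkpointing=True)
x = torch.randn(2, 16, 8, requires_grad=True)
out, size = stage(x, (4, 4))
out.sum().backward()
```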