|
|
|
@ -473,13 +473,7 @@ class DaViT(nn.Module):
|
|
|
|
|
self.grad_checkpointing = False
|
|
|
|
|
self.feature_info = []
|
|
|
|
|
|
|
|
|
|
self.patch_embed = PatchEmbed(
|
|
|
|
|
patch_size=patch_size,
|
|
|
|
|
in_chans=in_chans,
|
|
|
|
|
embed_dim=embed_dims[0],
|
|
|
|
|
overlapped=overlapped_patch
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
self.patch_embed = None
|
|
|
|
|
stages = []
|
|
|
|
|
|
|
|
|
|
for stage_id in range(self.num_stages):
|
|
|
|
@ -503,6 +497,7 @@ class DaViT(nn.Module):
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if stage_id == 0:
|
|
|
|
|
self.patch_embed = stage.patch_embed
|
|
|
|
|
stage.patch_embed = nn.Identity()
|
|
|
|
|
|
|
|
|
|
stages.append(stage)
|
|
|
|
|