Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 1470cebb6c
commit 2b5f391142

@ -473,13 +473,7 @@ class DaViT(nn.Module):
self.grad_checkpointing = False
self.feature_info = []
self.patch_embed = PatchEmbed(
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dims[0],
overlapped=overlapped_patch
)
self.patch_embed = None
stages = []
for stage_id in range(self.num_stages):
@ -503,6 +497,7 @@ class DaViT(nn.Module):
)
if stage_id == 0:
self.patch_embed = stage.patch_embed
stage.patch_embed = nn.Identity()
stages.append(stage)

Loading…
Cancel
Save