diff --git a/timm/models/davit.py b/timm/models/davit.py index b50e88e3..c5538de3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -92,7 +92,7 @@ class PatchEmbed(nn.Module): stride=patch_size, padding=(3, 3)) self.norm = nn.LayerNorm(embed_dim) - self.norm_after = False + self.norm_after = True if patch_size[0] == 2: kernel = 3 if overlapped else 2 pad = 1 if overlapped else 0 @@ -103,7 +103,7 @@ class PatchEmbed(nn.Module): stride=patch_size, padding=to_2tuple(pad)) self.norm = nn.LayerNorm(in_chans) - self.norm_after = True + self.norm_after = False def forward(self, x : Tensor, size: Tuple[int, int]):