From 347386580f8b5b888c72be55c7af4d0bedad56ee Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Thu, 8 Dec 2022 10:54:42 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 15 +++++++--------
 1 file changed, 7 insertions(+), 8 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index c5538de3..47129b3f 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -92,7 +92,7 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=(3, 3))
             self.norm = nn.LayerNorm(embed_dim)
-            self.norm_after = True
+            self.norm_after = False
         if patch_size[0] == 2:
             kernel = 3 if overlapped else 2
             pad = 1 if overlapped else 0
@@ -103,7 +103,7 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=to_2tuple(pad))
             self.norm = nn.LayerNorm(in_chans)
-            self.norm_after = False
+            self.norm_after = True
 
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
@@ -111,15 +111,14 @@ class PatchEmbed(nn.Module):
         #dim = x.dim()
         #if dim == 3:
         in_shape = x.shape
-        B = in_shape[0]
-        C = in_shape[-1]
         if self.norm_after == False:
+            B, HW, C = in_shape
             x = self.norm(x)
-        x = x.reshape(B,
-                      H,
-                      W,
-                      C).permute(0, 3, 1, 2).contiguous()
+            x = x.reshape(B,
+                          H,
+                          W,
+                          C).permute(0, 3, 1, 2).contiguous()
 
 
         B, C, H, W = x.shape
         if W % self.patch_size[1] != 0:
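
Note (commentary, not part of the patch): below is a minimal standalone sketch of
the token -> NCHW conversion that this change moves under the
`norm_after == False` branch of PatchEmbed.forward. The names norm_after, H, W,
and C mirror the patch above; the concrete sizes (B=2, H=W=14, C=96) are example
values, and the snippet is illustrative rather than the timm module itself.

    import torch
    import torch.nn as nn

    norm_after = False          # the branch that pre-norms and reshapes tokens
    B, H, W, C = 2, 14, 14, 96  # example sizes, not taken from the patch

    norm = nn.LayerNorm(C)
    x = torch.randn(B, H * W, C)  # (B, HW, C) token sequence from a prior stage

    if not norm_after:
        B, HW, C = x.shape        # unpacking assumes a 3-dim token input
        x = norm(x)               # pre-norm over the channel dimension
        # tokens -> (B, C, H, W) feature map for the strided conv projection
        x = x.reshape(B, H, W, C).permute(0, 3, 1, 2).contiguous()

    print(x.shape)  # torch.Size([2, 96, 14, 14])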