diff --git a/timm/models/davit.py b/timm/models/davit.py index 58931cfd..b50e88e3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -92,6 +92,7 @@ class PatchEmbed(nn.Module): stride=patch_size, padding=(3, 3)) self.norm = nn.LayerNorm(embed_dim) + self.norm_after = False if patch_size[0] == 2: kernel = 3 if overlapped else 2 pad = 1 if overlapped else 0 @@ -102,18 +103,23 @@ class PatchEmbed(nn.Module): stride=patch_size, padding=to_2tuple(pad)) self.norm = nn.LayerNorm(in_chans) + self.norm_after = True def forward(self, x : Tensor, size: Tuple[int, int]): H, W = size - dim = x.dim() - if dim == 3: - B, HW, C = x.shape + #dim = x.dim() + #if dim == 3: + in_shape = x.shape + B = in_shape[0] + C = in_shape[-1] + + if self.norm_after == False: x = self.norm(x) - x = x.reshape(B, - H, - W, - C).permute(0, 3, 1, 2).contiguous() + x = x.reshape(B, + H, + W, + C).permute(0, 3, 1, 2).contiguous() B, C, H, W = x.shape if W % self.patch_size[1] != 0: @@ -124,7 +130,8 @@ class PatchEmbed(nn.Module): x = self.proj(x) newsize = (x.size(2), x.size(3)) x = x.flatten(2).transpose(1, 2) - if dim == 4: + #if dim == 4: + if self.norm_after == True: x = self.norm(x) return x, newsize