diff --git a/timm/models/davit.py b/timm/models/davit.py
index 59850ff1..23c28b6a 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -89,7 +89,6 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=(3, 3))
             self.norm = nn.LayerNorm(embed_dim)
-            self.norm_after = True
         if patch_size[0] == 2:
             kernel = 3 if overlapped else 2
             pad = 1 if overlapped else 0
@@ -100,17 +99,13 @@ class PatchEmbed(nn.Module):
                 stride=patch_size,
                 padding=to_2tuple(pad))
             self.norm = nn.LayerNorm(in_chans)
-            self.norm_after = False
-
 
+    @register_notrace_function  # reason: dim in control sequence
     def forward(self, x : Tensor, size: Tuple[int, int]):
         H, W = size
-        in_shape = x.shape
-
-
-        # norm_after variable used as a workaround to original len(x.shape) == 3
-        if self.norm_after == False:
-            B, HW, C = in_shape
+        dim = len(x.shape)
+        if dim == 3:
+            B, HW, C = x.shape
             x = self.norm(x)
             x = x.reshape(B,
                           H,
@@ -118,14 +113,15 @@ class PatchEmbed(nn.Module):
                           C).permute(0, 3, 1, 2).contiguous()
 
         B, C, H, W = x.shape
-
-        x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
-        x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+        if W % self.patch_size[1] != 0:
+            x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+        if H % self.patch_size[0] != 0:
+            x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
 
         x = self.proj(x)
         newsize = (x.size(2), x.size(3))
         x = x.flatten(2).transpose(1, 2)
-        if self.norm_after == True:
+        if dim == 4:
             x = self.norm(x)
         return x, newsize
 
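The new `if W % ...`/`if H % ...` guards fix a real edge case, not just style: in the removed lines, an already-aligned input (`W % self.patch_size[1] == 0`) made `self.patch_size[1] - W % self.patch_size[1]` equal to a full `patch_size[1]`, so `F.pad` grew aligned inputs by an entire extra patch. Below is a minimal sketch of the guarded behavior; `pad_to_patch_multiple` is a hypothetical standalone helper written for illustration, not part of timm.

```python
import torch
import torch.nn.functional as F


def pad_to_patch_multiple(x: torch.Tensor, patch_size) -> torch.Tensor:
    # Hypothetical helper mirroring the guarded padding added in this diff:
    # right/bottom-pad an NCHW tensor so H and W become multiples of
    # patch_size, but only when they are not already aligned.
    _, _, H, W = x.shape
    if W % patch_size[1] != 0:
        x = F.pad(x, (0, patch_size[1] - W % patch_size[1]))
    if H % patch_size[0] != 0:
        x = F.pad(x, (0, 0, 0, patch_size[0] - H % patch_size[0]))
    return x


# Aligned input passes through unchanged; the old unconditional F.pad
# would have grown it to (1, 3, 228, 228).
aligned = torch.randn(1, 3, 224, 224)
assert pad_to_patch_multiple(aligned, (4, 4)).shape == (1, 3, 224, 224)

# Misaligned input is padded up to the next patch multiple.
misaligned = torch.randn(1, 3, 225, 227)
assert pad_to_patch_multiple(misaligned, (4, 4)).shape == (1, 3, 228, 228)
```

The restored `dim = len(x.shape)` branch is also why the diff attaches `@register_notrace_function` ("reason: dim in control sequence"): shape-dependent Python control flow cannot be symbolically traced by torch.fx, so the function is registered to be skipped during tracing instead of relying on the old `norm_after` flag workaround.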