diff --git a/timm/models/davit.py b/timm/models/davit.py index d054f2ab..3aab3f6e 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -121,10 +121,9 @@ class PatchEmbed(nn.Module): C).permute(0, 3, 1, 2).contiguous() B, C, H, W = x.shape - if W % self.patch_size[1] != 0: - x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) - if H % self.patch_size[0] != 0: - x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) x = self.proj(x) newsize = (x.size(2), x.size(3))