diff --git a/timm/models/davit.py b/timm/models/davit.py index 2e9d3ab9..97f8813b 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -166,9 +166,9 @@ class PatchEmbed(nn.Module): x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) - x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + x = F.pad(x, (0, self.patch_size[1] - torch.floor(W % self.patch_size[1]))) - x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + x = F.pad(x, (0, 0, 0, self.patch_size[0] - torch.floor(H % self.patch_size[0]))) x = self.proj(x)