diff --git a/timm/models/davit.py b/timm/models/davit.py index d764b6e6..04b72ace 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -226,7 +226,7 @@ class PatchEmbed(nn.Module): x = self.proj(x) #x = x.flatten(2).transpose(1, 2) if self.norm.normalized_shape[0] == self.embed_dim: - x = self.norm(x) + x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return x