@ -226,7 +226,7 @@ class PatchEmbed(nn.Module):
x = self.proj(x)
#x = x.flatten(2).transpose(1, 2)
if self.norm.normalized_shape[0] == self.embed_dim:
x = self.norm(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x