|
|
|
@ -226,7 +226,7 @@ class PatchEmbed(nn.Module):
|
|
|
|
|
x = self.proj(x)
|
|
|
|
|
#x = x.flatten(2).transpose(1, 2)
|
|
|
|
|
if self.norm.normalized_shape[0] == self.embed_dim:
|
|
|
|
|
x = self.norm(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, C, H, W)
|
|
|
|
|
x = self.norm(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|