diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py
index 42997fb8..6d322987 100644
--- a/timm/models/layers/patch_embed.py
+++ b/timm/models/layers/patch_embed.py
@@ -24,9 +24,11 @@ class PatchEmbed(nn.Module):
         self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
         self.num_patches = self.grid_size[0] * self.grid_size[1]
         self.flatten = flatten
+        if norm_layer is not None:
+            assert flatten, "Only use `norm_layer` if `flatten` is True"
 
         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
-        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+        self.norm = norm_layer(embed_dim) if norm_layer else None
 
     def forward(self, x):
         B, C, H, W = x.shape
@@ -35,5 +37,6 @@ class PatchEmbed(nn.Module):
         x = self.proj(x)
         if self.flatten:
             x = x.flatten(2).transpose(1, 2)  # BCHW -> BNC
-        x = self.norm(x)
+        if self.norm is not None:
+            x = self.norm(x)
         return x
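
For reviewers, a minimal usage sketch of the behavior after this change. This assumes the existing PatchEmbed constructor defaults (img_size=224, patch_size=16, in_chans=3, embed_dim=768); nn.LayerNorm is just an example norm_layer, not something this patch prescribes:

    import torch
    import torch.nn as nn
    from timm.models.layers.patch_embed import PatchEmbed

    # norm_layer with flatten=True: proj -> flatten BCHW to BNC -> LayerNorm over embed_dim
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768,
                       norm_layer=nn.LayerNorm, flatten=True)
    out = embed(torch.randn(2, 3, 224, 224))
    print(out.shape)  # torch.Size([2, 196, 768]) -- (224/16)**2 = 196 patches

    # No norm_layer: self.norm is now None rather than nn.Identity(),
    # so forward() skips the norm call entirely.
    embed = PatchEmbed(norm_layer=None)

    # norm_layer with flatten=False now fails fast at construction time:
    try:
        PatchEmbed(norm_layer=nn.LayerNorm, flatten=False)
    except AssertionError as e:
        print(e)  # Only use `norm_layer` if `flatten` is True

Replacing nn.Identity() with None means the no-norm path avoids a module call in forward(), at the cost of the extra None check; the new assert makes the previously silent norm_layer/flatten=False combination (LayerNorm applied over the wrong, unflattened shape) an explicit constructor error.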