From abf02db3d7ac2e6cf62412fae668c5c74b51a731 Mon Sep 17 00:00:00 2001 From: Carl-Johann SIMON-GABRIEL Date: Sun, 25 Jul 2021 12:02:31 +0200 Subject: [PATCH] [M] making PatchEmbed safer --- timm/models/layers/patch_embed.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index 42997fb8..6d322987 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -24,9 +24,11 @@ class PatchEmbed(nn.Module): self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] self.flatten = flatten + if norm_layer is not None: + assert flatten, "Only use `norm_layer` if `flatten` is True" self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) - self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + self.norm = norm_layer(embed_dim) if norm_layer else None def forward(self, x): B, C, H, W = x.shape @@ -35,5 +37,5 @@ class PatchEmbed(nn.Module): x = self.proj(x) if self.flatten: x = x.flatten(2).transpose(1, 2) # BCHW -> BNC - x = self.norm(x) + x = self.norm(x) return x