From 67d140446bd0a23fcb433819902cf0c51a1bfb80 Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sun, 20 Feb 2022 22:28:05 +0100 Subject: [PATCH] Fix bug in classification head --- timm/models/swin_transformer_v2_cr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index aeb713a5..033ad694 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -918,7 +918,7 @@ class SwinTransformerV2CR(nn.Module): self.patch_size: int = patch_size self.input_resolution: Tuple[int, int] = img_size self.window_size: int = window_size - self.num_features: int = int(embed_dim * (2 ** len(depths) - 1)) + self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1)) # Init patch embedding self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim, patch_size=patch_size, norm_layer=norm_layer) @@ -1038,7 +1038,7 @@ class SwinTransformerV2CR(nn.Module): for stage in self.stages: output: torch.Tensor = stage(output) # Perform average pooling - output: torch.Tensor = self.average_pool(output) + output: torch.Tensor = self.average_pool(output).flatten(start_dim=1) # Predict classification classification: torch.Tensor = self.head(output) return classification