From 87b4d7a29af80d7c1334af36392f744e9382cdf0 Mon Sep 17 00:00:00 2001
From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com>
Date: Sat, 19 Feb 2022 22:47:02 +0100
Subject: [PATCH] Add get and reset classifier method

---
 timm/models/swin_transformer_v2_cr.py | 23 ++++++++++++++++++++++-
 1 file changed, 22 insertions(+), 1 deletion(-)

diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py
index 6b0ca1c1..7adf1ec0 100644
--- a/timm/models/swin_transformer_v2_cr.py
+++ b/timm/models/swin_transformer_v2_cr.py
@@ -831,9 +831,11 @@ class SwinTransformerV2CR(nn.Module):
         # Call super constructor
         super(SwinTransformerV2CR, self).__init__()
         # Save parameters
+        self.num_classes: int = num_classes
         self.patch_size: int = patch_size
         self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
+        self.num_features: int = int(embed_dim * (2 ** len(depths) - 1))
         # Init patch embedding
         self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size)
@@ -863,7 +865,7 @@ class SwinTransformerV2CR(nn.Module):
             ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
-        self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1),
+        self.head: nn.Module = nn.Linear(in_features=self.num_features,
                                          out_features=num_classes)
 
     def update_resolution(self,
@@ -889,6 +891,25 @@ class SwinTransformerV2CR(nn.Module):
                                     new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)),
                                                           new_patch_resolution[1] // (2 ** max(index - 1, 0))))
 
+    def get_classifier(self) -> nn.Module:
+        """ Method returns the classification head of the model.
+        Returns:
+            head (nn.Module): Current classification head
+        """
+        head: nn.Module = self.head
+        return head
+
+    def reset_classifier(self, num_classes: int, global_pool: str = '') -> None:
+        """ Method resets the classification head
+
+        Args:
+            num_classes (int): Number of classes to be predicted
+            global_pool (str): Unused
+        """
+        self.num_classes: int = num_classes
+        self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes) \
+            if num_classes > 0 else nn.Identity()
+
     def forward_features(self, input: torch.Tensor) -> List[torch.Tensor]:
         """ Forward pass to extract feature maps of each stage.
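
Usage note (not part of the patch): a minimal sketch of how the two new methods
are meant to be used, following the get_classifier/reset_classifier convention
of other timm models. The model name passed to timm.create_model below is an
assumption and may differ between timm versions; any registered
SwinTransformerV2CR variant behaves the same way.

    import timm
    import torch.nn as nn

    # Build a SwinTransformerV2CR variant; the name 'swinv2_cr_tiny_224' is assumed here.
    model = timm.create_model('swinv2_cr_tiny_224', pretrained=False, num_classes=1000)

    # get_classifier() returns the current head (an nn.Linear by default).
    head = model.get_classifier()
    print(head)

    # reset_classifier() re-creates the head for a new number of classes,
    # e.g. before fine-tuning on a 10-class dataset.
    model.reset_classifier(num_classes=10)
    assert model.get_classifier().out_features == 10

    # With num_classes=0 the head is replaced by nn.Identity().
    model.reset_classifier(num_classes=0)
    assert isinstance(model.get_classifier(), nn.Identity)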