@@ -831,9 +831,11 @@ class SwinTransformerV2CR(nn.Module):
         # Call super constructor
         super(SwinTransformerV2CR, self).__init__()
         # Save parameters
+        self.num_classes: int = num_classes
         self.patch_size: int = patch_size
         self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
+        self.num_features: int = int(embed_dim * (2 ** len(depths) - 1))
         # Init patch embedding
         self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size)
@@ -863,7 +865,7 @@ class SwinTransformerV2CR(nn.Module):
                 ))
         # Init final adaptive average pooling, and classification head
         self.average_pool: nn.Module = nn.AdaptiveAvgPool2d(1)
-        self.head: nn.Module = nn.Linear(in_features=embed_dim * (2 ** len(depths) - 1),
+        self.head: nn.Module = nn.Linear(in_features=self.num_features,
                                          out_features=num_classes)

     def update_resolution(self,
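Note on the two hunks above: the head's `in_features` is no longer recomputed inline; it now reads the new `self.num_features` attribute, which carries the same expression that was previously hard-coded in the `nn.Linear` call, so the head's shape is unchanged. A quick sanity check of that expression (the `embed_dim` and `depths` values below are illustrative assumptions, not taken from the diff):

```python
# Illustrative values only; embed_dim and depths are assumptions, not from the diff.
embed_dim = 96
depths = (2, 2, 6, 2)

num_features = int(embed_dim * (2 ** len(depths) - 1))  # same expression as the new attribute
print(num_features)  # 96 * (2 ** 4 - 1) = 96 * 15 = 1440
```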
@@ -889,6 +891,25 @@ class SwinTransformerV2CR(nn.Module):
                                     new_input_resolution=(new_patch_resolution[0] // (2 ** max(index - 1, 0)),
                                                           new_patch_resolution[1] // (2 ** max(index - 1, 0))))

+    def get_classifier(self) -> nn.Module:
+        """ Method returns the classification head of the model.
+        Returns:
+            head (nn.Module): Current classification head
+        """
+        head: nn.Module = self.head
+        return head
+
+    def reset_classifier(self, num_classes: int, global_pool: str = '') -> None:
+        """ Method resets the classification head
+
+        Args:
+            num_classes (int): Number of classes to be predicted
+            global_pool (str): Unused
+        """
+        self.num_classes: int = num_classes
+        self.head: nn.Module = nn.Linear(in_features=self.num_features, out_features=num_classes) \
+            if num_classes > 0 else nn.Identity()
+
     def forward_features(self,
                          input: torch.Tensor) -> List[torch.Tensor]:
         """ Forward pass to extract feature maps of each stage.
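The last hunk adds the `get_classifier` / `reset_classifier` pair that other timm models also expose. The standalone sketch below is not the actual model; it only mimics the head-management pattern introduced above (a `num_features` value fixed at construction, `get_classifier` returning the current head, and `reset_classifier` swapping in a fresh `nn.Linear`, or `nn.Identity` when `num_classes` is 0). The default `num_features=768` is an arbitrary placeholder.

```python
import torch.nn as nn


class TinyClassifierHost(nn.Module):
    """Minimal stand-in mirroring the head-management API added in the hunk above."""

    def __init__(self, num_features: int = 768, num_classes: int = 1000) -> None:
        super().__init__()
        self.num_features: int = num_features
        self.num_classes: int = num_classes
        self.head: nn.Module = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()

    def get_classifier(self) -> nn.Module:
        # Return the current classification head, whatever it currently is.
        return self.head

    def reset_classifier(self, num_classes: int, global_pool: str = '') -> None:
        # Replace the head; num_classes == 0 removes it for pure feature extraction.
        self.num_classes = num_classes
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()


host = TinyClassifierHost()
host.reset_classifier(num_classes=10)
print(host.get_classifier())   # Linear(in_features=768, out_features=10, bias=True)
host.reset_classifier(num_classes=0)
print(host.get_classifier())   # Identity()
```

Against the real SwinTransformerV2CR the calls have the same shape: `model.reset_classifier(num_classes=10)` to retarget the head, `model.reset_classifier(num_classes=0)` to drop it.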