@@ -918,7 +918,7 @@ class SwinTransformerV2CR(nn.Module):
         self.patch_size: int = patch_size
         self.input_resolution: Tuple[int, int] = img_size
         self.window_size: int = window_size
-        self.num_features: int = int(embed_dim * (2 ** len(depths) - 1))
+        self.num_features: int = int(embed_dim * 2 ** (len(depths) - 1))
         # Init patch embedding
         self.patch_embedding: nn.Module = PatchEmbedding(in_channels=in_chans, out_channels=embed_dim,
                                                          patch_size=patch_size, norm_layer=norm_layer)
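
Quick sanity check of the corrected num_features expression. The config below is an assumption for illustration (a Swin-T-style embed_dim=96, depths=(2, 2, 6, 2)), not something fixed by this diff; only the parenthesization differs between the two expressions:

    embed_dim = 96
    depths = (2, 2, 6, 2)  # assumed Swin-T-style config, not taken from this diff

    # Buggy grouping: 2 ** len(depths) - 1 == 15, so num_features == 1440
    buggy = int(embed_dim * (2 ** len(depths) - 1))

    # Fixed: the channel width doubles once per downsampling step, i.e.
    # len(depths) - 1 times, so num_features == 96 * 2 ** 3 == 768
    fixed = int(embed_dim * 2 ** (len(depths) - 1))

    print(buggy, fixed)  # 1440 768
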
@@ -1038,7 +1038,7 @@ class SwinTransformerV2CR(nn.Module):
         for stage in self.stages:
             output: torch.Tensor = stage(output)
         # Perform average pooling
-        output: torch.Tensor = self.average_pool(output)
+        output: torch.Tensor = self.average_pool(output).flatten(start_dim=1)
         # Predict classification
         classification: torch.Tensor = self.head(output)
         return classification
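
Minimal shape sketch of why the flatten is needed before the classification head. The AdaptiveAvgPool2d/Linear stand-ins and the tensor shapes below are assumptions for illustration, not read from this diff:

    import torch
    import torch.nn as nn

    # Assumed stand-ins for self.average_pool and self.head; not taken from this diff
    average_pool = nn.AdaptiveAvgPool2d(1)
    head = nn.Linear(768, 1000)

    output = torch.randn(2, 768, 7, 7)       # [B, C, H, W] feature map from the last stage (assumed shape)
    pooled = average_pool(output)            # [2, 768, 1, 1]; head(pooled) would raise a shape error
    flattened = pooled.flatten(start_dim=1)  # [2, 768], matching head's in_features
    logits = head(flattened)                 # [2, 1000]
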