|
|
|
@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module):
|
|
|
|
|
Returns:
|
|
|
|
|
features (List[torch.Tensor]): List of feature maps from each stage
|
|
|
|
|
"""
|
|
|
|
|
# Check input resolution
|
|
|
|
|
assert input.shape[2:] == self.input_resolution, \
|
|
|
|
|
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
|
|
|
|
"update_resolution the provided method."
|
|
|
|
|
# Perform patch embedding
|
|
|
|
|
output: torch.Tensor = self.patch_embedding(input)
|
|
|
|
|
# Init list to store feature
|
|
|
|
@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module):
|
|
|
|
|
Returns:
|
|
|
|
|
classification (torch.Tensor): Classification of the shape (B, num_classes)
|
|
|
|
|
"""
|
|
|
|
|
# Check input resolution
|
|
|
|
|
assert input.shape[2:] == self.input_resolution, \
|
|
|
|
|
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
|
|
|
|
|
"update_resolution the provided method."
|
|
|
|
|
# Perform patch embedding
|
|
|
|
|
output: torch.Tensor = self.patch_embedding(input)
|
|
|
|
|
# Forward pass of each stage
|
|
|
|
|