Check input resolution

pull/1150/head
Christoph Reich 3 years ago
parent 81bf0b4033
commit ff5f6bcd6c

@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module):
Returns: Returns:
features (List[torch.Tensor]): List of feature maps from each stage features (List[torch.Tensor]): List of feature maps from each stage
""" """
# Check input resolution
assert input.shape[2:] == self.input_resolution, \
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
"update_resolution the provided method."
# Perform patch embedding # Perform patch embedding
output: torch.Tensor = self.patch_embedding(input) output: torch.Tensor = self.patch_embedding(input)
# Init list to store feature # Init list to store feature
@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module):
Returns: Returns:
classification (torch.Tensor): Classification of the shape (B, num_classes) classification (torch.Tensor): Classification of the shape (B, num_classes)
""" """
# Check input resolution
assert input.shape[2:] == self.input_resolution, \
"Input resolution and utilized resolution does not match. Please update the models resolution by calling " \
"update_resolution the provided method."
# Perform patch embedding # Perform patch embedding
output: torch.Tensor = self.patch_embedding(input) output: torch.Tensor = self.patch_embedding(input)
# Forward pass of each stage # Forward pass of each stage

Loading…
Cancel
Save