diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 9659b5ec..6b0ca1c1 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module): Returns: features (List[torch.Tensor]): List of feature maps from each stage """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." # Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Init list to store feature @@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module): Returns: classification (torch.Tensor): Classification of the shape (B, num_classes) """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." # Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Forward pass of each stage