From ff5f6bcd6cb451418e93164ca43c18adfd92f36f Mon Sep 17 00:00:00 2001 From: Christoph Reich <34400551+ChristophReich1996@users.noreply.github.com> Date: Sat, 19 Feb 2022 22:42:02 +0100 Subject: [PATCH] Check input resolution --- timm/models/swin_transformer_v2_cr.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 9659b5ec..6b0ca1c1 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -899,6 +899,10 @@ class SwinTransformerV2CR(nn.Module): Returns: features (List[torch.Tensor]): List of feature maps from each stage """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." # Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Init list to store feature @@ -919,6 +923,10 @@ class SwinTransformerV2CR(nn.Module): Returns: classification (torch.Tensor): Classification of the shape (B, num_classes) """ + # Check input resolution + assert input.shape[2:] == self.input_resolution, \ + "Input resolution and utilized resolution does not match. Please update the models resolution by calling " \ + "update_resolution the provided method." # Perform patch embedding output: torch.Tensor = self.patch_embedding(input) # Forward pass of each stage