From e265d4bc31584b7469801393591a4702120cca4b Mon Sep 17 00:00:00 2001 From: Zhiyuan Chen Date: Sat, 21 Nov 2020 05:55:30 +0800 Subject: [PATCH] relaxe size constraints in vit --- timm/models/vision_transformer.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 72f3a61a..4fc126ec 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -146,20 +146,24 @@ class Block(nn.Module): class PatchEmbed(nn.Module): """ Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, img_width=None, img_height=None, patch_width=None, patch_height=None): super().__init__() - img_size = to_2tuple(img_size) - patch_size = to_2tuple(patch_size) - num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) - self.img_size = img_size - self.patch_size = patch_size - self.num_patches = num_patches + if not img_width: + img_width = img_size + if not img_height: + img_height = img_height + if not patch_width: + patch_width = patch_size + if not patch_height: + patch_height = patch_size + self.img_size = (img_height, img_width) + self.patch_size = (patch_height, patch_width) + self.num_patches = (img_width // patch_width) * (img_height // patch_height) self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) def forward(self, x): B, C, H, W = x.shape - # FIXME look at relaxing size constraints assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x).flatten(2).transpose(1, 2)