diff --git a/timm/models/tnt.py b/timm/models/tnt.py index cc732677..8e038718 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -14,7 +14,9 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import load_pretrained from timm.models.layers import Mlp, DropPath, trunc_normal_ +from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model +from timm.models.vision_transformer import resize_pos_embed def _cfg(url='', **kwargs): @@ -118,11 +120,15 @@ class PixelEmbed(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): super().__init__() - num_patches = (img_size // patch_size) ** 2 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # grid_size property necessary for resizing positional embedding + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + num_patches = (self.grid_size[0]) * (self.grid_size[1]) self.img_size = img_size self.num_patches = num_patches self.in_dim = in_dim - new_patch_size = math.ceil(patch_size / stride) + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] self.new_patch_size = new_patch_size self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) @@ -130,11 +136,11 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - assert H == self.img_size and W == self.img_size, \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." x = self.proj(x) x = self.unfold(x) - x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) x = x + pixel_pos x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) return x @@ -155,7 +161,7 @@ class TNT(nn.Module): num_patches = self.pixel_embed.num_patches self.num_patches = num_patches new_patch_size = self.pixel_embed.new_patch_size - num_pixel = new_patch_size ** 2 + num_pixel = new_patch_size[0] * new_patch_size[1] self.norm1_proj = norm_layer(num_pixel * in_dim) self.proj = nn.Linear(num_pixel * in_dim, embed_dim) @@ -163,7 +169,7 @@ class TNT(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -224,6 +230,14 @@ class TNT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + if state_dict['patch_pos'].shape != model.patch_pos.shape: + state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'], + model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size) + return state_dict + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, @@ -231,7 +245,8 @@ def tnt_s_patch16_224(pretrained=False, **kwargs): model.default_cfg = default_cfgs['tnt_s_patch16_224'] if pretrained: load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=checkpoint_filter_fn) return model