diff --git a/timm/models/tnt.py b/timm/models/tnt.py index cc732677..8e038718 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -14,7 +14,9 @@ from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.models.helpers import load_pretrained from timm.models.layers import Mlp, DropPath, trunc_normal_ +from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model +from timm.models.vision_transformer import resize_pos_embed def _cfg(url='', **kwargs): @@ -118,11 +120,15 @@ class PixelEmbed(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4): super().__init__() - num_patches = (img_size // patch_size) ** 2 + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + # grid_size property necessary for resizing positional embedding + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + num_patches = (self.grid_size[0]) * (self.grid_size[1]) self.img_size = img_size self.num_patches = num_patches self.in_dim = in_dim - new_patch_size = math.ceil(patch_size / stride) + new_patch_size = [math.ceil(ps / stride) for ps in patch_size] self.new_patch_size = new_patch_size self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride) @@ -130,11 +136,11 @@ class PixelEmbed(nn.Module): def forward(self, x, pixel_pos): B, C, H, W = x.shape - assert H == self.img_size and W == self.img_size, \ - f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})." + assert H == self.img_size[0] and W == self.img_size[1], \ + f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
x = self.proj(x) x = self.unfold(x) - x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size) + x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1]) x = x + pixel_pos x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2) return x @@ -155,7 +161,7 @@ class TNT(nn.Module): num_patches = self.pixel_embed.num_patches self.num_patches = num_patches new_patch_size = self.pixel_embed.new_patch_size - num_pixel = new_patch_size ** 2 + num_pixel = new_patch_size[0] * new_patch_size[1] self.norm1_proj = norm_layer(num_pixel * in_dim) self.proj = nn.Linear(num_pixel * in_dim, embed_dim) @@ -163,7 +169,7 @@ class TNT(nn.Module): self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim)) - self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size)) + self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1])) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -224,6 +230,14 @@ class TNT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + """ resize the checkpoint patch position embedding (patch_pos) when its shape differs from the model's""" + if state_dict['patch_pos'].shape != model.patch_pos.shape: + state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'], + model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size) + return state_dict + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, @@ -231,7 +245,8 @@ def tnt_s_patch16_224(pretrained=False, **kwargs): model.default_cfg = default_cfgs['tnt_s_patch16_224'] if pretrained: load_pretrained( - model, num_classes=model.num_classes, 
in_chans=kwargs.get('in_chans', 3)) + model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), + filter_fn=checkpoint_filter_fn) return model diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index cc7e0903..bef6dfb0 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa nn.init.ones_(m.weight) -def resize_pos_embed(posemb, posemb_new, num_tokens=1): +def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()): # Rescale the grid of position embeddings when loading from state_dict. Adapted from # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape) @@ -363,11 +363,13 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1): else: posemb_tok, posemb_grid = posemb[:, :0], posemb[0] gs_old = int(math.sqrt(len(posemb_grid))) - gs_new = int(math.sqrt(ntok_new)) - _logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new) + if not len(gs_new): # backwards compatibility + gs_new = [int(math.sqrt(ntok_new))] * 2 + assert len(gs_new) >= 2 + _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new) posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) - posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear') - posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1) posemb = torch.cat([posemb_tok, posemb_grid], dim=1) return posemb @@ -385,7 +387,8 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != 
model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1)) + v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), + model.patch_embed.grid_size) out_dict[k] = v return out_dict