@@ -14,7 +14,9 @@ from functools import partial
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.models.helpers import load_pretrained
 from timm.models.layers import Mlp, DropPath, trunc_normal_
+from timm.models.layers.helpers import to_2tuple
 from timm.models.registry import register_model
+from timm.models.vision_transformer import resize_pos_embed
 
 
 def _cfg(url='', **kwargs):
@@ -118,11 +120,15 @@ class PixelEmbed(nn.Module):
     """
     def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
         super().__init__()
-        num_patches = (img_size // patch_size) ** 2
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        # grid_size property necessary for resizing positional embedding
+        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
+        num_patches = (self.grid_size[0]) * (self.grid_size[1])
         self.img_size = img_size
         self.num_patches = num_patches
         self.in_dim = in_dim
-        new_patch_size = math.ceil(patch_size / stride)
+        new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
         self.new_patch_size = new_patch_size
 
         self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
@@ -130,11 +136,11 @@ class PixelEmbed(nn.Module):
 
     def forward(self, x, pixel_pos):
         B, C, H, W = x.shape
-        assert H == self.img_size and W == self.img_size, \
-            f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
+        assert H == self.img_size[0] and W == self.img_size[1], \
+            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
         x = self.proj(x)
         x = self.unfold(x)
-        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
+        x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
         x = x + pixel_pos
         x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
         return x
@@ -155,7 +161,7 @@ class TNT(nn.Module):
         num_patches = self.pixel_embed.num_patches
         self.num_patches = num_patches
         new_patch_size = self.pixel_embed.new_patch_size
-        num_pixel = new_patch_size ** 2
+        num_pixel = new_patch_size[0] * new_patch_size[1]
 
         self.norm1_proj = norm_layer(num_pixel * in_dim)
         self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
@@ -163,7 +169,7 @@ class TNT(nn.Module):
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
         self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
-        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size))
+        self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
@@ -224,6 +230,14 @@ class TNT(nn.Module):
         return x
 
 
+def checkpoint_filter_fn(state_dict, model):
+    """ resize the patch position embedding from the checkpoint if it doesn't match the model """
+    if state_dict['patch_pos'].shape != model.patch_pos.shape:
+        state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
+            model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
+    return state_dict
+
+
 @register_model
 def tnt_s_patch16_224(pretrained=False, **kwargs):
     model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
@@ -231,7 +245,8 @@ def tnt_s_patch16_224(pretrained=False, **kwargs):
     model.default_cfg = default_cfgs['tnt_s_patch16_224']
     if pretrained:
         load_pretrained(
-            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
+            model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3),
+            filter_fn=checkpoint_filter_fn)
     return model
 
 
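Taken together, these changes make the TNT patch position embedding resizable: PixelEmbed now derives its patch grid from to_2tuple'd sizes and records it as grid_size, and checkpoint_filter_fn reinterpolates a pretrained patch_pos through resize_pos_embed whenever the checkpoint grid disagrees with the model grid. A minimal usage sketch of the intended effect — assuming a timm build with this patch applied and a load_pretrained that forwards the model to filter_fn:

    import torch
    from timm import create_model

    # Instantiate TNT-S at a non-default resolution; img_size is forwarded
    # through **kwargs to the TNT constructor. The pretrained checkpoint holds
    # patch_pos for a 14x14 grid (224 / 16); at img_size=448 the model expects
    # a 28x28 grid, so checkpoint_filter_fn resizes patch_pos while loading.
    model = create_model('tnt_s_patch16_224', pretrained=True, img_size=448)
    model.eval()

    with torch.no_grad():
        out = model(torch.randn(1, 3, 448, 448))
    print(out.shape)  # torch.Size([1, 1000])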
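The resizing itself is not part of this diff; resize_pos_embed is imported from timm.models.vision_transformer. The sketch below is an illustrative re-implementation of the idea, not the library source — the function name, the bilinear mode, and align_corners setting are assumptions:

    import math
    import torch
    import torch.nn.functional as F

    def resize_pos_embed_sketch(posemb, posemb_new, num_tokens=1, gs_new=()):
        # Split prefix tokens (e.g. the class token) from the grid tokens.
        posemb_tok = posemb[:, :num_tokens]
        posemb_grid = posemb[0, num_tokens:]
        gs_old = int(math.sqrt(posemb_grid.shape[0]))
        if not gs_new:  # fall back to a square grid inferred from the target
            gs_new = [int(math.sqrt(posemb_new.shape[1] - num_tokens))] * 2
        # Reshape the flat grid tokens to 2D, interpolate, and flatten back.
        posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
        posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear', align_corners=False)
        posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
        return torch.cat([posemb_tok, posemb_grid], dim=1)

In the diff above, the target grid is supplied explicitly as model.pixel_embed.grid_size, which is exactly why PixelEmbed now records grid_size as an (H, W) tuple instead of assuming a square 224/16 layout.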