|
|
@ -10,6 +10,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
Paper link: https://arxiv.org/abs/2103.14899
|
|
|
|
Paper link: https://arxiv.org/abs/2103.14899
|
|
|
|
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
|
|
|
|
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright IBM All Rights Reserved.
|
|
|
|
# Copyright IBM All Rights Reserved.
|
|
|
@ -40,30 +42,49 @@ def _cfg(url='', **kwargs):
|
|
|
|
'url': url,
|
|
|
|
'url': url,
|
|
|
|
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
|
|
|
|
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
|
|
|
|
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
|
|
|
|
# 'first_conv': 'patch_embed.proj',
|
|
|
|
'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
|
|
|
|
'classifier': 'head',
|
|
|
|
'classifier': ('head.0', 'head.1'),
|
|
|
|
**kwargs
|
|
|
|
**kwargs
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
default_cfgs = {
|
|
|
|
default_cfgs = {
|
|
|
|
'crossvit_15_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
|
|
|
|
'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
|
|
|
|
'crossvit_15_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth'),
|
|
|
|
'crossvit_15_dagger_240': _cfg(
|
|
|
|
'crossvit_15_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth'),
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',
|
|
|
|
'crossvit_18_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
|
|
|
|
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
|
|
|
'crossvit_18_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth'),
|
|
|
|
),
|
|
|
|
'crossvit_18_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth'),
|
|
|
|
'crossvit_15_dagger_408': _cfg(
|
|
|
|
'crossvit_9_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
|
|
|
|
'crossvit_9_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth'),
|
|
|
|
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
|
|
|
'crossvit_base_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
|
|
|
|
),
|
|
|
|
'crossvit_small_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
|
|
|
|
'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
|
|
|
|
'crossvit_tiny_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
|
|
|
|
'crossvit_18_dagger_240': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',
|
|
|
|
|
|
|
|
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
'crossvit_18_dagger_408': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
|
|
|
|
|
|
|
|
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
|
|
|
|
|
|
|
|
'crossvit_9_dagger_240': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',
|
|
|
|
|
|
|
|
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
'crossvit_base_240': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
|
|
|
|
|
|
|
|
'crossvit_small_240': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
|
|
|
|
|
|
|
|
'crossvit_tiny_240': _cfg(
|
|
|
|
|
|
|
|
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
class PatchEmbed(nn.Module):
|
|
|
|
""" Image to Patch Embedding
|
|
|
|
""" Image to Patch Embedding
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
@ -117,17 +138,19 @@ class CrossAttention(nn.Module):
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
self.proj_drop = nn.Dropout(proj_drop)
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
|
|
|
|
|
|
|
|
B, N, C = x.shape
|
|
|
|
B, N, C = x.shape
|
|
|
|
q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H)
|
|
|
|
# B1C -> B1H(C/H) -> BH1(C/H)
|
|
|
|
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
|
|
|
|
q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
|
|
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H)
|
|
|
|
# BNC -> BNH(C/H) -> BHN(C/H)
|
|
|
|
|
|
|
|
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
|
|
|
|
|
|
# BNC -> BNH(C/H) -> BHN(C/H)
|
|
|
|
|
|
|
|
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
|
|
|
|
|
|
|
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
|
|
|
|
attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = attn.softmax(dim=-1)
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
attn = self.attn_drop(attn)
|
|
|
|
|
|
|
|
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
|
|
|
|
x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj(x)
|
|
|
|
x = self.proj_drop(x)
|
|
|
|
x = self.proj_drop(x)
|
|
|
|
return x
|
|
|
|
return x
|
|
|
@ -152,7 +175,7 @@ class CrossAttentionBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
class MultiScaleBlock(nn.Module):
|
|
|
|
class MultiScaleBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
|
|
|
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
@ -163,9 +186,9 @@ class MultiScaleBlock(nn.Module):
|
|
|
|
for d in range(num_branches):
|
|
|
|
for d in range(num_branches):
|
|
|
|
tmp = []
|
|
|
|
tmp = []
|
|
|
|
for i in range(depth[d]):
|
|
|
|
for i in range(depth[d]):
|
|
|
|
tmp.append(
|
|
|
|
tmp.append(Block(
|
|
|
|
Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
|
|
|
dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
|
|
|
|
if len(tmp) != 0:
|
|
|
|
if len(tmp) != 0:
|
|
|
|
self.blocks.append(nn.Sequential(*tmp))
|
|
|
|
self.blocks.append(nn.Sequential(*tmp))
|
|
|
|
|
|
|
|
|
|
|
@ -174,32 +197,36 @@ class MultiScaleBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
self.projs = nn.ModuleList()
|
|
|
|
self.projs = nn.ModuleList()
|
|
|
|
for d in range(num_branches):
|
|
|
|
for d in range(num_branches):
|
|
|
|
if dim[d] == dim[(d+1) % num_branches] and False:
|
|
|
|
if dim[d] == dim[(d + 1) % num_branches] and False:
|
|
|
|
tmp = [nn.Identity()]
|
|
|
|
tmp = [nn.Identity()]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])]
|
|
|
|
tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
|
|
|
|
self.projs.append(nn.Sequential(*tmp))
|
|
|
|
self.projs.append(nn.Sequential(*tmp))
|
|
|
|
|
|
|
|
|
|
|
|
self.fusion = nn.ModuleList()
|
|
|
|
self.fusion = nn.ModuleList()
|
|
|
|
for d in range(num_branches):
|
|
|
|
for d in range(num_branches):
|
|
|
|
d_ = (d+1) % num_branches
|
|
|
|
d_ = (d + 1) % num_branches
|
|
|
|
nh = num_heads[d_]
|
|
|
|
nh = num_heads[d_]
|
|
|
|
if depth[-1] == 0: # backward capability:
|
|
|
|
if depth[-1] == 0: # backward capability:
|
|
|
|
self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
self.fusion.append(
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
|
|
|
CrossAttentionBlock(
|
|
|
|
|
|
|
|
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tmp = []
|
|
|
|
tmp = []
|
|
|
|
for _ in range(depth[-1]):
|
|
|
|
for _ in range(depth[-1]):
|
|
|
|
tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
|
|
|
tmp.append(CrossAttentionBlock(
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
|
|
|
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
|
|
|
|
|
|
|
|
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
|
|
|
|
self.fusion.append(nn.Sequential(*tmp))
|
|
|
|
self.fusion.append(nn.Sequential(*tmp))
|
|
|
|
|
|
|
|
|
|
|
|
self.revert_projs = nn.ModuleList()
|
|
|
|
self.revert_projs = nn.ModuleList()
|
|
|
|
for d in range(num_branches):
|
|
|
|
for d in range(num_branches):
|
|
|
|
if dim[(d+1) % num_branches] == dim[d] and False:
|
|
|
|
if dim[(d + 1) % num_branches] == dim[d] and False:
|
|
|
|
tmp = [nn.Identity()]
|
|
|
|
tmp = [nn.Identity()]
|
|
|
|
else:
|
|
|
|
else:
|
|
|
|
tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])]
|
|
|
|
tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
|
|
|
|
|
|
|
|
nn.Linear(dim[(d + 1) % num_branches], dim[d])]
|
|
|
|
self.revert_projs.append(nn.Sequential(*tmp))
|
|
|
|
self.revert_projs.append(nn.Sequential(*tmp))
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
|
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
@ -225,23 +252,29 @@ class MultiScaleBlock(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _compute_num_patches(img_size, patches):
|
|
|
|
def _compute_num_patches(img_size, patches):
|
|
|
|
return [i // p * i // p for i, p in zip(img_size,patches)]
|
|
|
|
return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CrossViT(nn.Module):
|
|
|
|
class CrossViT(nn.Module):
|
|
|
|
""" Vision Transformer with support for patch or hybrid CNN input stage
|
|
|
|
""" Vision Transformer with support for patch or hybrid CNN input stage
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=([1, 3, 1], [1, 3, 1], [1, 3, 1]),
|
|
|
|
|
|
|
|
num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
|
|
|
def __init__(
|
|
|
|
drop_path_rate=0., norm_layer=nn.LayerNorm, multi_conv=False):
|
|
|
|
self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
|
|
|
|
|
|
|
|
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
|
|
|
|
|
|
|
|
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
|
|
|
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
|
|
|
|
|
|
|
|
):
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
|
|
|
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
if not isinstance(img_size, list):
|
|
|
|
if not isinstance(img_size, (tuple, list)):
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
img_size = to_2tuple(img_size)
|
|
|
|
self.img_size = img_size
|
|
|
|
self.img_size = img_size
|
|
|
|
|
|
|
|
if not isinstance(img_scale, (tuple, list)):
|
|
|
|
num_patches = _compute_num_patches(img_size, patch_size)
|
|
|
|
img_scale = to_2tuple(img_scale)
|
|
|
|
|
|
|
|
self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale]
|
|
|
|
|
|
|
|
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
|
|
|
|
self.num_branches = len(patch_size)
|
|
|
|
self.num_branches = len(patch_size)
|
|
|
|
self.embed_dim = embed_dim
|
|
|
|
self.embed_dim = embed_dim
|
|
|
|
self.num_features = embed_dim[0] # to pass the tests
|
|
|
|
self.num_features = embed_dim[0] # to pass the tests
|
|
|
@ -252,8 +285,9 @@ class CrossViT(nn.Module):
|
|
|
|
setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
|
|
|
|
setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
|
|
|
|
setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
|
|
|
|
setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
|
|
|
|
|
|
|
|
|
|
|
|
for im_s, p, d in zip(img_size, patch_size, embed_dim):
|
|
|
|
for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
|
|
|
|
self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
|
|
|
|
self.patch_embed.append(
|
|
|
|
|
|
|
|
PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
|
|
|
|
|
|
|
|
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
self.pos_drop = nn.Dropout(p=drop_rate)
|
|
|
|
|
|
|
|
|
|
|
@ -264,14 +298,16 @@ class CrossViT(nn.Module):
|
|
|
|
for idx, block_cfg in enumerate(depth):
|
|
|
|
for idx, block_cfg in enumerate(depth):
|
|
|
|
curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
|
|
|
|
curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
|
|
|
|
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
|
|
|
|
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
|
|
|
|
blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
|
|
|
blk = MultiScaleBlock(
|
|
|
|
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_,
|
|
|
|
embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
|
|
|
norm_layer=norm_layer)
|
|
|
|
qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
|
|
|
|
dpr_ptr += curr_depth
|
|
|
|
dpr_ptr += curr_depth
|
|
|
|
self.blocks.append(blk)
|
|
|
|
self.blocks.append(blk)
|
|
|
|
|
|
|
|
|
|
|
|
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
|
|
|
|
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
|
|
|
|
self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
|
|
|
|
self.head = nn.ModuleList([
|
|
|
|
|
|
|
|
nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
for i in range(self.num_branches)])
|
|
|
|
|
|
|
|
|
|
|
|
for i in range(self.num_branches):
|
|
|
|
for i in range(self.num_branches):
|
|
|
|
if hasattr(self, f'pos_embed_{i}'):
|
|
|
|
if hasattr(self, f'pos_embed_{i}'):
|
|
|
@ -301,13 +337,16 @@ class CrossViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
def reset_classifier(self, num_classes, global_pool=''):
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.num_classes = num_classes
|
|
|
|
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)])
|
|
|
|
self.head = nn.ModuleList(
|
|
|
|
|
|
|
|
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
|
|
|
|
|
|
|
|
range(self.num_branches)])
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
xs = []
|
|
|
|
xs = []
|
|
|
|
for i, patch_embed in enumerate(self.patch_embed):
|
|
|
|
for i, patch_embed in enumerate(self.patch_embed):
|
|
|
|
x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x
|
|
|
|
ss = self.img_size_scaled[i]
|
|
|
|
|
|
|
|
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x
|
|
|
|
tmp = patch_embed(x_)
|
|
|
|
tmp = patch_embed(x_)
|
|
|
|
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
|
|
|
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
|
|
|
|
cls_tokens = cls_tokens.expand(B, -1, -1)
|
|
|
|
cls_tokens = cls_tokens.expand(B, -1, -1)
|
|
|
@ -322,14 +361,16 @@ class CrossViT(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
|
|
|
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm
|
|
|
|
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
|
|
|
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
|
|
|
|
out = [x[:, 0] for x in xs]
|
|
|
|
return tuple([x[:, 0] for x in xs])
|
|
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
xs = self.forward_features(x)
|
|
|
|
xs = self.forward_features(x)
|
|
|
|
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
|
|
|
|
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
|
|
|
|
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
|
|
|
|
if isinstance(self.head[0], nn.Identity):
|
|
|
|
|
|
|
|
# FIXME to pass current passthrough features tests, could use better approach
|
|
|
|
|
|
|
|
ce_logits = tuple(ce_logits)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
|
|
|
|
return ce_logits
|
|
|
|
return ce_logits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -353,109 +394,101 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
|
|
|
|
pretrained_filter_fn=pretrained_filter_fn,
|
|
|
|
pretrained_filter_fn=pretrained_filter_fn,
|
|
|
|
**kwargs)
|
|
|
|
**kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_tiny_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_tiny_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(
|
|
|
|
model_args = dict(
|
|
|
|
img_size=[240, 224], patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
num_heads=[3, 3], mlp_ratio=[4, 4, 1], qkv_bias=True,
|
|
|
|
num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_tiny_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_small_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_small_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
num_heads=[6, 6], mlp_ratio=[4, 4, 1], qkv_bias=True,
|
|
|
|
num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_small_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_base_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_base_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
|
|
|
|
num_heads=[12, 12], mlp_ratio=[4, 4, 1], qkv_bias=True,
|
|
|
|
num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_base_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_9_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_9_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
|
|
|
num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_9_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_15_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_15_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_18_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_18_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_9_dagger_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_9_dagger_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
|
|
|
img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
|
|
|
|
num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_9_dagger_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_15_dagger_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_15_dagger_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_dagger_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_15_dagger_384(pretrained=False, **kwargs):
|
|
|
|
def crossvit_15_dagger_408(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[408, 384],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_15_dagger_384', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_18_dagger_224(pretrained=False, **kwargs):
|
|
|
|
def crossvit_18_dagger_240(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[240, 224],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_dagger_224', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def crossvit_18_dagger_384(pretrained=False, **kwargs):
|
|
|
|
def crossvit_18_dagger_408(pretrained=False, **kwargs):
|
|
|
|
model_args = dict(img_size=[408, 384],
|
|
|
|
model_args = dict(
|
|
|
|
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True,
|
|
|
|
num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
|
|
|
|
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
|
|
|
|
model = _create_crossvit(variant='crossvit_18_dagger_384', pretrained=pretrained, **model_args)
|
|
|
|
|
|
|
|
return model
|
|
|
|
return model
|
|
|
|