Post crossvit merge cleanup, change model names to reflect input size, cleanup img size vs scale handling, fix tests

pull/821/head
Ross Wightman 3 years ago
parent a897e0ebcc
commit f1808e0970

@ -189,10 +189,12 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
input_tensor = torch.randn((batch_size, *input_size)) input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled) # test forward_features (always unpooled)
outputs = model.forward_features(input_tensor) if 'crossvit' not in model_name:
if isinstance(outputs, tuple): # FIXME remove crossvit exception
outputs = outputs[0] outputs = model.forward_features(input_tensor)
assert outputs.shape[1] == model.num_features if isinstance(outputs, tuple):
outputs = outputs[0]
assert outputs.shape[1] == model.num_features
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0) model.reset_classifier(0)

@ -1,9 +1,9 @@
from .byoanet import * from .byoanet import *
from .byobnet import * from .byobnet import *
from .cait import * from .cait import *
from .crossvit import *
from .coat import * from .coat import *
from .convit import * from .convit import *
from .crossvit import *
from .cspnet import * from .cspnet import *
from .densenet import * from .densenet import *
from .dla import * from .dla import *
@ -37,6 +37,7 @@ from .sknet import *
from .swin_transformer import * from .swin_transformer import *
from .tnt import * from .tnt import *
from .tresnet import * from .tresnet import *
from .twins import *
from .vgg import * from .vgg import *
from .visformer import * from .visformer import *
from .vision_transformer import * from .vision_transformer import *
@ -45,7 +46,6 @@ from .vovnet import *
from .xception import * from .xception import *
from .xception_aligned import * from .xception_aligned import *
from .xcit import * from .xcit import *
from .twins import *
from .factory import create_model, split_model_name, safe_model_name from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .helpers import load_checkpoint, resume_checkpoint, model_parameters

@ -10,6 +10,8 @@
Paper link: https://arxiv.org/abs/2103.14899 Paper link: https://arxiv.org/abs/2103.14899
Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py Original code: https://github.com/IBM/CrossViT/blob/main/models/crossvit.py
NOTE: model names have been renamed from originals to represent actual input res all *_224 -> *_240 and *_384 -> *_408
""" """
# Copyright IBM All Rights Reserved. # Copyright IBM All Rights Reserved.
@ -40,30 +42,49 @@ def _cfg(url='', **kwargs):
'url': url, 'url': url,
'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None, 'num_classes': 1000, 'input_size': (3, 240, 240), 'pool_size': None,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
# 'first_conv': 'patch_embed.proj', 'first_conv': ('patch_embed.0.proj', 'patch_embed.1.proj'),
'classifier': 'head', 'classifier': ('head.0', 'head.1'),
**kwargs **kwargs
} }
default_cfgs = { default_cfgs = {
'crossvit_15_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'), 'crossvit_15_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_224.pth'),
'crossvit_15_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth'), 'crossvit_15_dagger_240': _cfg(
'crossvit_15_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth'), url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_224.pth',
'crossvit_18_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
'crossvit_18_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth'), ),
'crossvit_18_dagger_384': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth'), 'crossvit_15_dagger_408': _cfg(
'crossvit_9_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'), url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_15_dagger_384.pth',
'crossvit_9_dagger_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth'), input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
'crossvit_base_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'), ),
'crossvit_small_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'), 'crossvit_18_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_224.pth'),
'crossvit_tiny_224': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'), 'crossvit_18_dagger_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_224.pth',
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_18_dagger_408': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_18_dagger_384.pth',
input_size=(3, 408, 408), first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_9_240': _cfg(url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_224.pth'),
'crossvit_9_dagger_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_9_dagger_224.pth',
first_conv=('patch_embed.0.proj.0', 'patch_embed.1.proj.0'),
),
'crossvit_base_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_base_224.pth'),
'crossvit_small_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_small_224.pth'),
'crossvit_tiny_240': _cfg(
url='https://github.com/IBM/CrossViT/releases/download/weights-0.1/crossvit_tiny_224.pth'),
} }
class PatchEmbed(nn.Module): class PatchEmbed(nn.Module):
""" Image to Patch Embedding """ Image to Patch Embedding
""" """
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False): def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, multi_conv=False):
super().__init__() super().__init__()
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
@ -117,17 +138,19 @@ class CrossAttention(nn.Module):
self.proj_drop = nn.Dropout(proj_drop) self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x): def forward(self, x):
B, N, C = x.shape B, N, C = x.shape
q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # B1C -> B1H(C/H) -> BH1(C/H) # B1C -> B1H(C/H) -> BH1(C/H)
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) q = self.wq(x[:, 0:1, ...]).reshape(B, 1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) # BNC -> BNH(C/H) -> BHN(C/H) # BNC -> BNH(C/H) -> BHN(C/H)
k = self.wk(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
# BNC -> BNH(C/H) -> BHN(C/H)
v = self.wv(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N attn = (q @ k.transpose(-2, -1)) * self.scale # BH1(C/H) @ BH(C/H)N -> BH1N
attn = attn.softmax(dim=-1) attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn) attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C x = (attn @ v).transpose(1, 2).reshape(B, 1, C) # (BH1N @ BHN(C/H)) -> BH1(C/H) -> B1H(C/H) -> B1C
x = self.proj(x) x = self.proj(x)
x = self.proj_drop(x) x = self.proj_drop(x)
return x return x
@ -152,7 +175,7 @@ class CrossAttentionBlock(nn.Module):
class MultiScaleBlock(nn.Module): class MultiScaleBlock(nn.Module):
def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., def __init__(self, dim, patches, depth, num_heads, mlp_ratio, qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__() super().__init__()
@ -163,9 +186,9 @@ class MultiScaleBlock(nn.Module):
for d in range(num_branches): for d in range(num_branches):
tmp = [] tmp = []
for i in range(depth[d]): for i in range(depth[d]):
tmp.append( tmp.append(Block(
Block(dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, dim=dim[d], num_heads=num_heads[d], mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer)) drop=drop, attn_drop=attn_drop, drop_path=drop_path[i], norm_layer=norm_layer))
if len(tmp) != 0: if len(tmp) != 0:
self.blocks.append(nn.Sequential(*tmp)) self.blocks.append(nn.Sequential(*tmp))
@ -174,32 +197,36 @@ class MultiScaleBlock(nn.Module):
self.projs = nn.ModuleList() self.projs = nn.ModuleList()
for d in range(num_branches): for d in range(num_branches):
if dim[d] == dim[(d+1) % num_branches] and False: if dim[d] == dim[(d + 1) % num_branches] and False:
tmp = [nn.Identity()] tmp = [nn.Identity()]
else: else:
tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d+1) % num_branches])] tmp = [norm_layer(dim[d]), act_layer(), nn.Linear(dim[d], dim[(d + 1) % num_branches])]
self.projs.append(nn.Sequential(*tmp)) self.projs.append(nn.Sequential(*tmp))
self.fusion = nn.ModuleList() self.fusion = nn.ModuleList()
for d in range(num_branches): for d in range(num_branches):
d_ = (d+1) % num_branches d_ = (d + 1) % num_branches
nh = num_heads[d_] nh = num_heads[d_]
if depth[-1] == 0: # backward capability: if depth[-1] == 0: # backward capability:
self.fusion.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, self.fusion.append(
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) CrossAttentionBlock(
dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
else: else:
tmp = [] tmp = []
for _ in range(depth[-1]): for _ in range(depth[-1]):
tmp.append(CrossAttentionBlock(dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias, qk_scale=qk_scale, tmp.append(CrossAttentionBlock(
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer)) dim=dim[d_], num_heads=nh, mlp_ratio=mlp_ratio[d], qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop, drop_path=drop_path[-1], norm_layer=norm_layer))
self.fusion.append(nn.Sequential(*tmp)) self.fusion.append(nn.Sequential(*tmp))
self.revert_projs = nn.ModuleList() self.revert_projs = nn.ModuleList()
for d in range(num_branches): for d in range(num_branches):
if dim[(d+1) % num_branches] == dim[d] and False: if dim[(d + 1) % num_branches] == dim[d] and False:
tmp = [nn.Identity()] tmp = [nn.Identity()]
else: else:
tmp = [norm_layer(dim[(d+1) % num_branches]), act_layer(), nn.Linear(dim[(d+1) % num_branches], dim[d])] tmp = [norm_layer(dim[(d + 1) % num_branches]), act_layer(),
nn.Linear(dim[(d + 1) % num_branches], dim[d])]
self.revert_projs.append(nn.Sequential(*tmp)) self.revert_projs.append(nn.Sequential(*tmp))
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]: def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
@ -225,23 +252,29 @@ class MultiScaleBlock(nn.Module):
def _compute_num_patches(img_size, patches): def _compute_num_patches(img_size, patches):
return [i // p * i // p for i, p in zip(img_size,patches)] return [i[0] // p * i[1] // p for i, p in zip(img_size, patches)]
class CrossViT(nn.Module): class CrossViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage """ Vision Transformer with support for patch or hybrid CNN input stage
""" """
def __init__(self, img_size=(224, 224), patch_size=(8, 16), in_chans=3, num_classes=1000, embed_dim=(192, 384), depth=([1, 3, 1], [1, 3, 1], [1, 3, 1]),
num_heads=(6, 12), mlp_ratio=(2., 2., 4.), qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., def __init__(
drop_path_rate=0., norm_layer=nn.LayerNorm, multi_conv=False): self, img_size=224, img_scale=(1.0, 1.0), patch_size=(8, 16), in_chans=3, num_classes=1000,
embed_dim=(192, 384), depth=((1, 3, 1), (1, 3, 1), (1, 3, 1)), num_heads=(6, 12), mlp_ratio=(2., 2., 4.),
qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=False
):
super().__init__() super().__init__()
self.num_classes = num_classes self.num_classes = num_classes
if not isinstance(img_size, list): if not isinstance(img_size, (tuple, list)):
img_size = to_2tuple(img_size) img_size = to_2tuple(img_size)
self.img_size = img_size self.img_size = img_size
if not isinstance(img_scale, (tuple, list)):
num_patches = _compute_num_patches(img_size, patch_size) img_scale = to_2tuple(img_scale)
self.img_size_scaled = [tuple([int(sj * si) for sj in img_size]) for si in img_scale]
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
self.num_branches = len(patch_size) self.num_branches = len(patch_size)
self.embed_dim = embed_dim self.embed_dim = embed_dim
self.num_features = embed_dim[0] # to pass the tests self.num_features = embed_dim[0] # to pass the tests
@ -252,8 +285,9 @@ class CrossViT(nn.Module):
setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i]))) setattr(self, f'pos_embed_{i}', nn.Parameter(torch.zeros(1, 1 + num_patches[i], embed_dim[i])))
setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i]))) setattr(self, f'cls_token_{i}', nn.Parameter(torch.zeros(1, 1, embed_dim[i])))
for im_s, p, d in zip(img_size, patch_size, embed_dim): for im_s, p, d in zip(self.img_size_scaled, patch_size, embed_dim):
self.patch_embed.append(PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv)) self.patch_embed.append(
PatchEmbed(img_size=im_s, patch_size=p, in_chans=in_chans, embed_dim=d, multi_conv=multi_conv))
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
@ -264,14 +298,16 @@ class CrossViT(nn.Module):
for idx, block_cfg in enumerate(depth): for idx, block_cfg in enumerate(depth):
curr_depth = max(block_cfg[:-1]) + block_cfg[-1] curr_depth = max(block_cfg[:-1]) + block_cfg[-1]
dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth] dpr_ = dpr[dpr_ptr:dpr_ptr + curr_depth]
blk = MultiScaleBlock(embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio, blk = MultiScaleBlock(
qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, embed_dim, num_patches, block_cfg, num_heads=num_heads, mlp_ratio=mlp_ratio,
norm_layer=norm_layer) qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr_, norm_layer=norm_layer)
dpr_ptr += curr_depth dpr_ptr += curr_depth
self.blocks.append(blk) self.blocks.append(blk)
self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)]) self.norm = nn.ModuleList([norm_layer(embed_dim[i]) for i in range(self.num_branches)])
self.head = nn.ModuleList([nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) self.head = nn.ModuleList([
nn.Linear(embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
for i in range(self.num_branches)])
for i in range(self.num_branches): for i in range(self.num_branches):
if hasattr(self, f'pos_embed_{i}'): if hasattr(self, f'pos_embed_{i}'):
@ -301,13 +337,16 @@ class CrossViT(nn.Module):
def reset_classifier(self, num_classes, global_pool=''): def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes self.num_classes = num_classes
self.head = nn.ModuleList([nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in range(self.num_branches)]) self.head = nn.ModuleList(
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])
def forward_features(self, x): def forward_features(self, x):
B, C, H, W = x.shape B, C, H, W = x.shape
xs = [] xs = []
for i, patch_embed in enumerate(self.patch_embed): for i, patch_embed in enumerate(self.patch_embed):
x_ = torch.nn.functional.interpolate(x, size=(self.img_size[i], self.img_size[i]), mode='bicubic') if H != self.img_size[i] else x ss = self.img_size_scaled[i]
x_ = torch.nn.functional.interpolate(x, size=ss, mode='bicubic') if H != ss[0] else x
tmp = patch_embed(x_) tmp = patch_embed(x_)
cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script cls_tokens = self.cls_token_0 if i == 0 else self.cls_token_1 # hard-coded for torch jit script
cls_tokens = cls_tokens.expand(B, -1, -1) cls_tokens = cls_tokens.expand(B, -1, -1)
@ -322,14 +361,16 @@ class CrossViT(nn.Module):
# NOTE: was before branch token section, move to here to assure all branch token are before layer norm # NOTE: was before branch token section, move to here to assure all branch token are before layer norm
xs = [norm(xs[i]) for i, norm in enumerate(self.norm)] xs = [norm(xs[i]) for i, norm in enumerate(self.norm)]
out = [x[:, 0] for x in xs] return tuple([x[:, 0] for x in xs])
return out
def forward(self, x): def forward(self, x):
xs = self.forward_features(x) xs = self.forward_features(x)
ce_logits = [head(xs[i]) for i, head in enumerate(self.head)] ce_logits = [head(xs[i]) for i, head in enumerate(self.head)]
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0) if isinstance(self.head[0], nn.Identity):
# FIXME to pass current passthrough features tests, could use better approach
ce_logits = tuple(ce_logits)
else:
ce_logits = torch.mean(torch.stack(ce_logits, dim=0), dim=0)
return ce_logits return ce_logits
@ -353,109 +394,101 @@ def _create_crossvit(variant, pretrained=False, **kwargs):
pretrained_filter_fn=pretrained_filter_fn, pretrained_filter_fn=pretrained_filter_fn,
**kwargs) **kwargs)
@register_model @register_model
def crossvit_tiny_224(pretrained=False, **kwargs): def crossvit_tiny_240(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
img_size=[240, 224], patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[96, 192], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[3, 3], mlp_ratio=[4, 4, 1], qkv_bias=True, num_heads=[3, 3], mlp_ratio=[4, 4, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_tiny_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_tiny_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_small_224(pretrained=False, **kwargs): def crossvit_small_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[6, 6], mlp_ratio=[4, 4, 1], qkv_bias=True, num_heads=[6, 6], mlp_ratio=[4, 4, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_small_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_small_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_base_224(pretrained=False, **kwargs): def crossvit_base_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[384, 768], depth=[[1, 4, 0], [1, 4, 0], [1, 4, 0]],
num_heads=[12, 12], mlp_ratio=[4, 4, 1], qkv_bias=True, num_heads=[12, 12], mlp_ratio=[4, 4, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_base_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_base_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_9_224(pretrained=False, **kwargs): def crossvit_9_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[4, 4], mlp_ratio=[3, 3, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_9_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_9_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_15_224(pretrained=False, **kwargs): def crossvit_15_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[6, 6], mlp_ratio=[3, 3, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_15_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_15_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_18_224(pretrained=False, **kwargs): def crossvit_18_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[7, 7], mlp_ratio=[3, 3, 1], **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs) model = _create_crossvit(variant='crossvit_18_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_18_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_9_dagger_224(pretrained=False, **kwargs): def crossvit_9_dagger_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]], img_scale=(1.0, 224 / 240), patch_size=[12, 16], embed_dim=[128, 256], depth=[[1, 3, 0], [1, 3, 0], [1, 3, 0]],
num_heads=[4, 4], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[4, 4], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) model = _create_crossvit(variant='crossvit_9_dagger_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_9_dagger_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_15_dagger_224(pretrained=False, **kwargs): def crossvit_15_dagger_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) model = _create_crossvit(variant='crossvit_15_dagger_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_15_dagger_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_15_dagger_384(pretrained=False, **kwargs): def crossvit_15_dagger_408(pretrained=False, **kwargs):
model_args = dict(img_size=[408, 384], model_args = dict(
patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]], img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[192, 384], depth=[[1, 5, 0], [1, 5, 0], [1, 5, 0]],
num_heads=[6, 6], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[6, 6], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) model = _create_crossvit(variant='crossvit_15_dagger_408', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_15_dagger_384', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_18_dagger_224(pretrained=False, **kwargs): def crossvit_18_dagger_240(pretrained=False, **kwargs):
model_args = dict(img_size=[240, 224], model_args = dict(
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], img_scale=(1.0, 224/240), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) model = _create_crossvit(variant='crossvit_18_dagger_240', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_18_dagger_224', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def crossvit_18_dagger_384(pretrained=False, **kwargs): def crossvit_18_dagger_408(pretrained=False, **kwargs):
model_args = dict(img_size=[408, 384], model_args = dict(
patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]], img_scale=(1.0, 384/408), patch_size=[12, 16], embed_dim=[224, 448], depth=[[1, 6, 0], [1, 6, 0], [1, 6, 0]],
num_heads=[7, 7], mlp_ratio=[3, 3, 1], qkv_bias=True, num_heads=[7, 7], mlp_ratio=[3, 3, 1], multi_conv=True, **kwargs)
norm_layer=partial(nn.LayerNorm, eps=1e-6), multi_conv=True, **kwargs) model = _create_crossvit(variant='crossvit_18_dagger_408', pretrained=pretrained, **model_args)
model = _create_crossvit(variant='crossvit_18_dagger_384', pretrained=pretrained, **model_args)
return model return model

Loading…
Cancel
Save