Push all MaxxViT weights to HF hub, cleanup impl, add feature map extraction support and prompote to 'std' architecture. Fix norm head for proper embedding / feat map output. Add new in12k + ft 1k weights.

pull/1641/head
Ross Wightman 2 years ago
parent ca38e1e73f
commit bed350f5e5

@ -27,8 +27,9 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
'eva_*', 'flexivit*'
]
#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', '
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
@ -53,7 +54,7 @@ MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
MAX_FWD_FX_SIZE = 224
MAX_FWD_FX_SIZE = 256
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224

@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
# FIXME / WARNING
This impl remains a WIP, some configs and models may vanish or change...
Papers:
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
partition_ratio: int = 32
window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
init_values: Optional[float] = None
act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d'
@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
stride: int = 1,
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention
drop_path: float = 0.,
):
super().__init__()
self.nchw_attn = transformer_cfg.use_nchw_attn
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn
self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
def init_weights(self, scheme=''):
@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
hidden_size=None,
pool_type='avg',
drop_rate=0.,
norm_layer=nn.LayerNorm,
act_layer=nn.Tanh,
norm_layer='layernorm2d',
act_layer='tanh',
):
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)),
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module):
self.num_features = self.embed_dim = cfg.embed_dim[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(
in_chs=in_chans,
@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module):
norm_layer=cfg.conv_cfg.norm_layer,
norm_eps=cfg.conv_cfg.norm_eps,
)
stride = self.stem.stride
self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
num_stages = len(cfg.embed_dim)
@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module):
)]
stride *= stage_stride
in_chs = out_chs
self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
if cfg.head_hidden_size:
self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpHead(
self.num_features,
num_classes,
hidden_size=cfg.head_hidden_size,
hidden_size=self.head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=final_norm_layer,
@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -1376,6 +1396,7 @@ def _next_cfg(
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=None,
no_block_attn=False,
init_values=1e-6,
rel_pos_type='mlp', # MLP by default for maxxvit
rel_pos_dim=512,
@ -1396,6 +1417,7 @@ def _next_cfg(
expand_first=False,
pool_type=pool_type,
window_size=window_size,
no_block_attn=no_block_attn, # enabled for MaxxViT-V2
init_values=init_values[1],
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
@ -1422,8 +1444,8 @@ def _tf_cfg():
model_cfgs = dict(
# Fiddling with configs / defaults / still pretraining
coatnet_pico_rw_224=MaxxVitCfg(
# timm specific CoAtNet configs
coatnet_pico_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 3, 5, 2),
stem_width=(32, 64),
@ -1432,7 +1454,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_nano_rw_224=MaxxVitCfg(
coatnet_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1442,7 +1464,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_0_rw_224=MaxxVitCfg(
coatnet_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1451,7 +1473,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
),
),
coatnet_1_rw_224=MaxxVitCfg(
coatnet_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1461,7 +1483,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
)
),
coatnet_2_rw_224=MaxxVitCfg(
coatnet_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1471,7 +1493,7 @@ model_cfgs = dict(
#init_values=1e-6,
),
),
coatnet_3_rw_224=MaxxVitCfg(
coatnet_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1482,8 +1504,8 @@ model_cfgs = dict(
),
),
# Highly experimental configs
coatnet_bn_0_rw_224=MaxxVitCfg(
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
coatnet_bn_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1494,7 +1516,7 @@ model_cfgs = dict(
transformer_norm_layer='batchnorm2d',
)
),
coatnet_rmlp_nano_rw_224=MaxxVitCfg(
coatnet_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1505,7 +1527,7 @@ model_cfgs = dict(
rel_pos_dim=384,
),
),
coatnet_rmlp_0_rw_224=MaxxVitCfg(
coatnet_rmlp_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1514,7 +1536,7 @@ model_cfgs = dict(
rel_pos_type='mlp',
),
),
coatnet_rmlp_1_rw_224=MaxxVitCfg(
coatnet_rmlp_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1526,7 +1548,7 @@ model_cfgs = dict(
rel_pos_dim=384, # was supposed to be 512, woops
),
),
coatnet_rmlp_1_rw2_224=MaxxVitCfg(
coatnet_rmlp_1_rw2=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1536,7 +1558,7 @@ model_cfgs = dict(
rel_pos_dim=512, # was supposed to be 512, woops
),
),
coatnet_rmlp_2_rw_224=MaxxVitCfg(
coatnet_rmlp_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1547,7 +1569,7 @@ model_cfgs = dict(
rel_pos_type='mlp'
),
),
coatnet_rmlp_3_rw_224=MaxxVitCfg(
coatnet_rmlp_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1559,14 +1581,14 @@ model_cfgs = dict(
),
),
coatnet_nano_cc_224=MaxxVitCfg(
coatnet_nano_cc=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(),
),
coatnext_nano_rw_224=MaxxVitCfg(
coatnext_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1578,99 +1600,95 @@ model_cfgs = dict(
),
# Trying to be like the CoAtNet paper configs
coatnet_0_224=MaxxVitCfg(
coatnet_0=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 5, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_1_224=MaxxVitCfg(
coatnet_1=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_2_224=MaxxVitCfg(
coatnet_2=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=128,
head_hidden_size=1024,
),
coatnet_3_224=MaxxVitCfg(
coatnet_3=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_4_224=MaxxVitCfg(
coatnet_4=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_5_224=MaxxVitCfg(
coatnet_5=MaxxVitCfg(
embed_dim=(256, 512, 1280, 2048),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=2048,
),
# Experimental MaxVit configs
maxvit_pico_rw_256=MaxxVitCfg(
maxvit_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(),
),
maxvit_nano_rw_256=MaxxVitCfg(
maxvit_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_224=MaxxVitCfg(
maxvit_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_256=MaxxVitCfg(
maxvit_tiny_pm=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_rmlp_pico_rw_256=MaxxVitCfg(
maxvit_rmlp_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_small_rw_224=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_rmlp_small_rw_256=MaxxVitCfg(
maxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
@ -1680,17 +1698,7 @@ model_cfgs = dict(
init_values=1e-6,
),
),
maxvit_rmlp_base_rw_224=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
),
),
maxvit_rmlp_base_rw_384=MaxxVitCfg(
maxvit_rmlp_base_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
@ -1701,15 +1709,7 @@ model_cfgs = dict(
),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
@ -1717,33 +1717,50 @@ model_cfgs = dict(
weight_init='normal',
**_next_cfg(),
),
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_next_cfg(),
),
maxxvit_rmlp_small_rw_256=MaxxVitCfg(
maxxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
),
maxxvit_rmlp_base_rw_224=MaxxVitCfg(
maxxvitv2_nano_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
weight_init='normal',
**_next_cfg(
no_block_attn=True,
rel_pos_type='bias',
),
),
maxxvit_rmlp_large_rw_224=MaxxVitCfg(
maxxvitv2_rmlp_base_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 12, 2),
block_type=('M',) * 4,
stem_width=(64, 128),
**_next_cfg(),
**_next_cfg(
no_block_attn=True,
),
),
maxxvitv2_rmlp_large_rw=MaxxVitCfg(
embed_dim=(160, 320, 640, 1280),
depths=(2, 6, 16, 2),
block_type=('M',) * 4,
stem_width=(80, 160),
head_hidden_size=1280,
**_next_cfg(
no_block_attn=True,
),
),
# Trying to be like the MaxViT paper configs
@ -1795,11 +1812,29 @@ model_cfgs = dict(
)
def checkpoint_filter_fn(state_dict, model: nn.Module):
model_state_dict = model.state_dict()
out_dict = {}
for k, v in state_dict.items():
if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
# adapt between conv2d / linear layers
assert v.ndim in (2, 4)
v = v.reshape(model_state_dict[k].shape)
out_dict[k] = v
return out_dict
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
if cfg_variant is None:
if variant in model_cfgs:
cfg_variant = variant
else:
cfg_variant = '_'.join(variant.split('_')[:-1])
return build_model_with_cfg(
MaxxVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
model_cfg=model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# Fiddling with configs / defaults / still pretraining
'coatnet_pico_rw_224': _cfg(url=''),
'coatnet_nano_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
'coatnet_pico_rw_224.untrained': _cfg(url=''),
'coatnet_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
crop_pct=0.9),
'coatnet_0_rw_224': _cfg(
'coatnet_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
'coatnet_1_rw_224': _cfg(
'coatnet_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
),
'coatnet_2_rw_224': _cfg(url=''),
'coatnet_3_rw_224': _cfg(url=''),
# Highly experimental configs
'coatnet_bn_0_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
#'coatnet_3_rw_224.untrained': _cfg(url=''),
# Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
'coatnet_bn_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95),
'coatnet_rmlp_nano_rw_224': _cfg(
'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
crop_pct=0.9),
'coatnet_rmlp_0_rw_224': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg(
'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_rmlp_1_rw2_224': _cfg(url=''),
'coatnet_rmlp_2_rw_224': _cfg(
'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
'coatnet_rmlp_3_rw_224': _cfg(url=''),
'coatnet_nano_cc_224': _cfg(url=''),
'coatnext_nano_rw_224': _cfg(
'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
'coatnet_nano_cc_224.untrained': _cfg(url=''),
'coatnext_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
crop_pct=0.9),
# Trying to be like the CoAtNet paper configs
'coatnet_0_224': _cfg(url=''),
'coatnet_1_224': _cfg(url=''),
'coatnet_2_224': _cfg(url=''),
'coatnet_3_224': _cfg(url=''),
'coatnet_4_224': _cfg(url=''),
'coatnet_5_224': _cfg(url=''),
# Experimental configs
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256': _cfg(
# ImagenNet-12k pretrain CoAtNet
'coatnet_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_3_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
'coatnet_0_224.untrained': _cfg(url=''),
'coatnet_1_224.untrained': _cfg(url=''),
'coatnet_2_224.untrained': _cfg(url=''),
'coatnet_3_224.untrained': _cfg(url=''),
'coatnet_4_224.untrained': _cfg(url=''),
'coatnet_5_224.untrained': _cfg(url=''),
# timm specific MaxVit configs, ImageNet-1k pretrain or untrained
'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(
'maxvit_tiny_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
'maxvit_tiny_rw_256': _cfg(
'maxvit_tiny_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_pico_rw_256': _cfg(
'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_tiny_rw_256': _cfg(
'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_small_rw_224': _cfg(
'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
crop_pct=0.9,
),
'maxvit_rmlp_small_rw_256': _cfg(
'maxvit_rmlp_small_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_base_rw_224': _cfg(
url='',
# timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
),
'maxvit_rmlp_base_rw_384': _cfg(
url='',
input_size=(3, 384, 384), pool_size=(12, 12)),
'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ ImageNet-12k pretrain
'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
),
'maxxvit_rmlp_nano_rw_256': _cfg(
# timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256': _cfg(
'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_base_rw_224': _cfg(url=''),
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
# timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# MaxViT models ported from official Tensorflow impl
'maxvit_tiny_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_tiny_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_small_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_large_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in21k': _cfg(
url=''),
'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in21k': _cfg(
url=''),
'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_224.in21k': _cfg(
url=''),
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
@register_model
def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
@register_model

Loading…
Cancel
Save