@@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
 The models without extra prefix / suffix (coatnet_0_224, maxvit_tiny_224, etc) are intended to
 match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
 
-# FIXME / WARNING
-# This impl remains a WIP, some configs and models may vanish or change...
-
 Papers:
 
 MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
     partition_ratio: int = 32
     window_size: Optional[Tuple[int, int]] = None
     grid_size: Optional[Tuple[int, int]] = None
+    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
+    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
     init_values: Optional[float] = None
    act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
@@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
             stride: int = 1,
             conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
             transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
-            use_nchw_attn: bool = False,  # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
-            use_block_attn: bool = True,  # FIXME for testing ConvNeXt conv w/o block attention
             drop_path: float = 0.,
     ):
         super().__init__()
+        self.nchw_attn = transformer_cfg.use_nchw_attn
         conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
         self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
 
         attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
-        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
-        self.nchw_attn = use_nchw_attn
-        self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
+        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
+        self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
         self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
 
     def init_weights(self, scheme=''):
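Note: the two constructor flags removed above now live on MaxxVitTransformerCfg (see the earlier hunk). A minimal sketch of the layout difference `use_nchw_attn` toggles, assuming only that the two partition layers differ in expected memory order (the permutes here are illustrative, not the exact implementation); per the removed FIXME comment, the NCHW path measured ~20-30% faster on TPU but 5-10% slower on GPU:

    import torch

    x = torch.randn(2, 64, 56, 56)    # NCHW activations from the preceding conv block
    x_cl = x.permute(0, 2, 3, 1)      # PartitionAttentionCl works channels-last (NHWC)
    x_2d = x_cl.permute(0, 3, 1, 2)   # PartitionAttention2d keeps NCHW, skipping the permutes
    assert torch.equal(x, x_2d)       # same values, different memory layout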
@@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
             hidden_size=None,
             pool_type='avg',
             drop_rate=0.,
-            norm_layer=nn.LayerNorm,
-            act_layer=nn.Tanh,
+            norm_layer='layernorm2d',
+            act_layer='tanh',
     ):
         super().__init__()
         self.drop_rate = drop_rate
+        self.in_features = in_features
+        self.hidden_size = hidden_size
         self.num_features = in_features
+        self.use_conv = not pool_type
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
 
         self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
         self.norm = norm_layer(in_features)
         self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         if hidden_size:
             self.pre_logits = nn.Sequential(OrderedDict([
-                ('fc', nn.Linear(in_features, hidden_size)),
+                ('fc', linear_layer(in_features, hidden_size)),
                 ('act', act_layer()),
             ]))
             self.num_features = hidden_size
         else:
             self.pre_logits = nn.Identity()
         self.drop = nn.Dropout(self.drop_rate)
-        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+        self.use_conv = self.global_pool.is_identity()
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+        if self.hidden_size:
+            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
+                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
+                with torch.no_grad():
+                    new_fc = linear_layer(self.in_features, self.hidden_size)
+                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
+                    new_fc.bias.copy_(self.pre_logits.fc.bias)
+                    self.pre_logits.fc = new_fc
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
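Note: the weight copy in reset() relies on a 1x1 Conv2d and a Linear computing the same matmul over pooled features. A small self-contained check of that equivalence (shapes chosen arbitrarily for illustration):

    import torch
    import torch.nn as nn

    lin = nn.Linear(768, 1000)
    conv = nn.Conv2d(768, 1000, kernel_size=1)
    with torch.no_grad():
        conv.weight.copy_(lin.weight.reshape(conv.weight.shape))  # (1000, 768) -> (1000, 768, 1, 1)
        conv.bias.copy_(lin.bias)

    x = torch.randn(2, 768)
    y_lin = lin(x)
    y_conv = conv(x[:, :, None, None]).flatten(1)  # treat features as a 1x1 NCHW map
    assert torch.allclose(y_lin, y_conv, atol=1e-5)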
@@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module):
         self.num_features = self.embed_dim = cfg.embed_dim[-1]
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+        self.feature_info = []
 
         self.stem = Stem(
             in_chs=in_chans,
@@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module):
             norm_layer=cfg.conv_cfg.norm_layer,
             norm_eps=cfg.conv_cfg.norm_eps,
         )
         stride = self.stem.stride
+        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
         feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
 
         num_stages = len(cfg.embed_dim)
@@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module):
             )]
             stride *= stage_stride
             in_chs = out_chs
+            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
 
         final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
-        if cfg.head_hidden_size:
+        self.head_hidden_size = cfg.head_hidden_size
+        if self.head_hidden_size:
             self.norm = nn.Identity()
             self.head = NormMlpHead(
                 self.num_features,
                 num_classes,
-                hidden_size=cfg.head_hidden_size,
+                hidden_size=self.head_hidden_size,
                 pool_type=global_pool,
                 drop_rate=drop_rate,
                 norm_layer=final_norm_layer,
@@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module):
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        if global_pool is None:
-            global_pool = self.head.global_pool.pool_type
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+        self.head.reset(num_classes, global_pool)
 
     def forward_features(self, x):
         x = self.stem(x)
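Note: with reset delegated to the head, swapping or removing the classifier no longer rebuilds the whole head module. A usage sketch (assumes a timm build where these models and NormMlpHead.reset() are present; coatnet_0_224 is used because its config sets head_hidden_size, so it gets a NormMlpHead):

    import torch
    import timm

    m = timm.create_model('coatnet_0_224')   # untrained; NormMlpHead via head_hidden_size=768
    m.reset_classifier(num_classes=0)        # head.reset() swaps the final fc for nn.Identity()
    out = m(torch.randn(1, 3, 224, 224))     # now returns 768-dim pre-logits features
    assert out.shape == (1, 768)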
@@ -1376,6 +1396,7 @@ def _next_cfg(
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
         window_size=None,
+        no_block_attn=False,
         init_values=1e-6,
         rel_pos_type='mlp',  # MLP by default for maxxvit
         rel_pos_dim=512,
@@ -1396,6 +1417,7 @@ def _next_cfg(
             expand_first=False,
             pool_type=pool_type,
             window_size=window_size,
+            no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
             init_values=init_values[1],
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
@@ -1422,8 +1444,8 @@ def _tf_cfg():
 
 model_cfgs = dict(
-    # Fiddling with configs / defaults / still pretraining
-    coatnet_pico_rw_224=MaxxVitCfg(
+    # timm specific CoAtNet configs
+    coatnet_pico_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
@@ -1432,7 +1454,7 @@ model_cfgs = dict(
             conv_attn_ratio=0.25,
         ),
     ),
-    coatnet_nano_rw_224=MaxxVitCfg(
+    coatnet_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1442,7 +1464,7 @@ model_cfgs = dict(
             conv_attn_ratio=0.25,
         ),
     ),
-    coatnet_0_rw_224=MaxxVitCfg(
+    coatnet_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1451,7 +1473,7 @@ model_cfgs = dict(
             transformer_shortcut_bias=False,
         ),
     ),
-    coatnet_1_rw_224=MaxxVitCfg(
+    coatnet_1_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1461,7 +1483,7 @@ model_cfgs = dict(
             transformer_shortcut_bias=False,
         )
     ),
-    coatnet_2_rw_224=MaxxVitCfg(
+    coatnet_2_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
@@ -1471,7 +1493,7 @@ model_cfgs = dict(
             #init_values=1e-6,
         ),
     ),
-    coatnet_3_rw_224=MaxxVitCfg(
+    coatnet_3_rw=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=(96, 192),
@@ -1482,8 +1504,8 @@ model_cfgs = dict(
         ),
     ),
 
-    # Highly experimental configs
-    coatnet_bn_0_rw_224=MaxxVitCfg(
+    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
+    coatnet_bn_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1494,7 +1516,7 @@ model_cfgs = dict(
             transformer_norm_layer='batchnorm2d',
         )
     ),
-    coatnet_rmlp_nano_rw_224=MaxxVitCfg(
+    coatnet_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1505,7 +1527,7 @@ model_cfgs = dict(
             rel_pos_dim=384,
         ),
     ),
-    coatnet_rmlp_0_rw_224=MaxxVitCfg(
+    coatnet_rmlp_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1514,7 +1536,7 @@ model_cfgs = dict(
             rel_pos_type='mlp',
         ),
     ),
-    coatnet_rmlp_1_rw_224=MaxxVitCfg(
+    coatnet_rmlp_1_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1526,7 +1548,7 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
-    coatnet_rmlp_1_rw2_224=MaxxVitCfg(
+    coatnet_rmlp_1_rw2=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1536,7 +1558,7 @@ model_cfgs = dict(
             rel_pos_dim=512,  # was supposed to be 512, woops
         ),
     ),
-    coatnet_rmlp_2_rw_224=MaxxVitCfg(
+    coatnet_rmlp_2_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
@@ -1547,7 +1569,7 @@ model_cfgs = dict(
             rel_pos_type='mlp'
         ),
     ),
-    coatnet_rmlp_3_rw_224=MaxxVitCfg(
+    coatnet_rmlp_3_rw=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=(96, 192),
@@ -1559,14 +1581,14 @@ model_cfgs = dict(
         ),
     ),
 
-    coatnet_nano_cc_224=MaxxVitCfg(
+    coatnet_nano_cc=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
         **_rw_coat_cfg(),
     ),
-    coatnext_nano_rw_224=MaxxVitCfg(
+    coatnext_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1578,99 +1600,95 @@ model_cfgs = dict(
     ),
 
     # Trying to be like the CoAtNet paper configs
-    coatnet_0_224=MaxxVitCfg(
+    coatnet_0=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 5, 2),
         stem_width=64,
+        head_hidden_size=768,
     ),
-    coatnet_1_224=MaxxVitCfg(
+    coatnet_1=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=64,
+        head_hidden_size=768,
     ),
-    coatnet_2_224=MaxxVitCfg(
+    coatnet_2=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=128,
+        head_hidden_size=1024,
     ),
-    coatnet_3_224=MaxxVitCfg(
+    coatnet_3=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=192,
+        head_hidden_size=1536,
     ),
-    coatnet_4_224=MaxxVitCfg(
+    coatnet_4=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 12, 28, 2),
         stem_width=192,
+        head_hidden_size=1536,
     ),
-    coatnet_5_224=MaxxVitCfg(
+    coatnet_5=MaxxVitCfg(
         embed_dim=(256, 512, 1280, 2048),
         depths=(2, 12, 28, 2),
         stem_width=192,
+        head_hidden_size=2048,
     ),
 
     # Experimental MaxVit configs
-    maxvit_pico_rw_256=MaxxVitCfg(
+    maxvit_pico_rw=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
         **_rw_max_cfg(),
     ),
-    maxvit_nano_rw_256=MaxxVitCfg(
+    maxvit_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_tiny_rw_224=MaxxVitCfg(
+    maxvit_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_tiny_rw_256=MaxxVitCfg(
+    maxvit_tiny_pm=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('M',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_rmlp_pico_rw_256=MaxxVitCfg(
+    maxvit_rmlp_pico_rw=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+    maxvit_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
+    maxvit_rmlp_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_small_rw_224=MaxxVitCfg(
-        embed_dim=(96, 192, 384, 768),
-        depths=(2, 2, 5, 2),
-        block_type=('M',) * 4,
-        stem_width=(32, 64),
-        **_rw_max_cfg(
-            rel_pos_type='mlp',
-            init_values=1e-6,
-        ),
-    ),
-    maxvit_rmlp_small_rw_256=MaxxVitCfg(
+    maxvit_rmlp_small_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
@@ -1680,17 +1698,7 @@ model_cfgs = dict(
             init_values=1e-6,
         ),
     ),
-    maxvit_rmlp_base_rw_224=MaxxVitCfg(
-        embed_dim=(96, 192, 384, 768),
-        depths=(2, 6, 14, 2),
-        block_type=('M',) * 4,
-        stem_width=(32, 64),
-        head_hidden_size=768,
-        **_rw_max_cfg(
-            rel_pos_type='mlp',
-        ),
-    ),
-    maxvit_rmlp_base_rw_384=MaxxVitCfg(
+    maxvit_rmlp_base_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         block_type=('M',) * 4,
@@ -1701,15 +1709,7 @@ model_cfgs = dict(
         ),
     ),
-    maxvit_tiny_pm_256=MaxxVitCfg(
-        embed_dim=(64, 128, 256, 512),
-        depths=(2, 2, 5, 2),
-        block_type=('PM',) * 4,
-        stem_width=(32, 64),
-        **_rw_max_cfg(),
-    ),
-    maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
@@ -1717,33 +1717,50 @@ model_cfgs = dict(
         weight_init='normal',
         **_next_cfg(),
     ),
-    maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_next_cfg(),
     ),
-    maxxvit_rmlp_small_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_small_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(48, 96),
         **_next_cfg(),
     ),
-    maxxvit_rmlp_base_rw_224=MaxxVitCfg(
+
+    maxxvitv2_nano_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
-        depths=(2, 6, 14, 2),
+        depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(48, 96),
-        **_next_cfg(),
+        weight_init='normal',
+        **_next_cfg(
+            no_block_attn=True,
+            rel_pos_type='bias',
+        ),
     ),
-    maxxvit_rmlp_large_rw_224=MaxxVitCfg(
+    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 12, 2),
         block_type=('M',) * 4,
         stem_width=(64, 128),
-        **_next_cfg(),
+        **_next_cfg(
+            no_block_attn=True,
+        ),
+    ),
+    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
+        embed_dim=(160, 320, 640, 1280),
+        depths=(2, 6, 16, 2),
+        block_type=('M',) * 4,
+        stem_width=(80, 160),
+        head_hidden_size=1280,
+        **_next_cfg(
+            no_block_attn=True,
+        ),
     ),
 
     # Trying to be like the MaxViT paper configs
@@ -1795,11 +1812,29 @@ model_cfgs = dict(
 )
 
 
+def checkpoint_filter_fn(state_dict, model: nn.Module):
+    model_state_dict = model.state_dict()
+    out_dict = {}
+    for k, v in state_dict.items():
+        if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
+            # adapt between conv2d / linear layers
+            assert v.ndim in (2, 4)
+            v = v.reshape(model_state_dict[k].shape)
+        out_dict[k] = v
+    return out_dict
+
+
 def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
+    if cfg_variant is None:
+        if variant in model_cfgs:
+            cfg_variant = variant
+        else:
+            cfg_variant = '_'.join(variant.split('_')[:-1])
     return build_model_with_cfg(
         MaxxVit, variant, pretrained,
-        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
+        model_cfg=model_cfgs[cfg_variant],
         feature_cfg=dict(flatten_sequential=True),
+        pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
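Note: two things do the heavy lifting in this hunk. checkpoint_filter_fn reshapes any weight whose element count matches but whose rank differs, so checkpoints saved with Linear heads load into 1x1 Conv2d heads and vice versa. And _create_maxxvit now derives the config key by stripping a trailing resolution suffix when the full variant name has no entry, which is what lets the resolution-suffixed model names above share one config. A standalone sketch of that fallback (hypothetical helper name, tiny stand-in cfg set):

    def resolve_cfg_variant(variant, cfgs):
        # mirror of the fallback above: exact match first, else drop the last '_' segment
        if variant in cfgs:
            return variant
        return '_'.join(variant.split('_')[:-1])

    cfgs = {'maxvit_rmlp_base_rw', 'coatnet_0_rw'}
    assert resolve_cfg_variant('maxvit_rmlp_base_rw_384', cfgs) == 'maxvit_rmlp_base_rw'
    assert resolve_cfg_variant('coatnet_0_rw', cfgs) == 'coatnet_0_rw'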
@@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs):
 
 default_cfgs = generate_default_cfgs({
-    # Fiddling with configs / defaults / still pretraining
-    'coatnet_pico_rw_224': _cfg(url=''),
-    'coatnet_nano_rw_224': _cfg(
+    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
+    'coatnet_pico_rw_224.untrained': _cfg(url=''),
+    'coatnet_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
-    'coatnet_0_rw_224': _cfg(
+    'coatnet_0_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
-    'coatnet_1_rw_224': _cfg(
+    'coatnet_1_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
-    'coatnet_2_rw_224': _cfg(url=''),
-    'coatnet_3_rw_224': _cfg(url=''),
 
-    # Highly experimental configs
-    'coatnet_bn_0_rw_224': _cfg(
+    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
+    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    #'coatnet_3_rw_224.untrained': _cfg(url=''),
+
+    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
+    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
+    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
-    'coatnet_rmlp_nano_rw_224': _cfg(
+    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
-    'coatnet_rmlp_0_rw_224': _cfg(url=''),
-    'coatnet_rmlp_1_rw_224': _cfg(
+    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
+    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
-    'coatnet_rmlp_1_rw2_224': _cfg(url=''),
-    'coatnet_rmlp_2_rw_224': _cfg(
+    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
-    'coatnet_rmlp_3_rw_224': _cfg(url=''),
-    'coatnet_nano_cc_224': _cfg(url=''),
-    'coatnext_nano_rw_224': _cfg(
+    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
+    'coatnet_nano_cc_224.untrained': _cfg(url=''),
+    'coatnext_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
         crop_pct=0.9),
 
-    # Trying to be like the CoAtNet paper configs
-    'coatnet_0_224': _cfg(url=''),
-    'coatnet_1_224': _cfg(url=''),
-    'coatnet_2_224': _cfg(url=''),
-    'coatnet_3_224': _cfg(url=''),
-    'coatnet_4_224': _cfg(url=''),
-    'coatnet_5_224': _cfg(url=''),
+    # ImageNet-12k pretrain CoAtNet
+    'coatnet_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_3_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
 
-    # Experimental configs
-    'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(
+    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
+    'coatnet_0_224.untrained': _cfg(url=''),
+    'coatnet_1_224.untrained': _cfg(url=''),
+    'coatnet_2_224.untrained': _cfg(url=''),
+    'coatnet_3_224.untrained': _cfg(url=''),
+    'coatnet_4_224.untrained': _cfg(url=''),
+    'coatnet_5_224.untrained': _cfg(url=''),
+
+    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
+    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_rw_224': _cfg(
+    'maxvit_tiny_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
-    'maxvit_tiny_rw_256': _cfg(
+    'maxvit_tiny_rw_256.untrained': _cfg(
         url='',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_pico_rw_256': _cfg(
+    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+
+    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
+    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_nano_rw_256': _cfg(
+    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_tiny_rw_256': _cfg(
+    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_small_rw_224': _cfg(
+    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
         crop_pct=0.9,
     ),
-    'maxvit_rmlp_small_rw_256': _cfg(
+    'maxvit_rmlp_small_rw_256.untrained': _cfg(
         url='',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_base_rw_224': _cfg(
-        url='',
+
+    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
+    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
     ),
-    'maxvit_rmlp_base_rw_384': _cfg(
-        url='',
-        input_size=(3, 384, 384), pool_size=(12, 12)),
-    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    # timm specific MaxVit w/ ImageNet-12k pretrain
+    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+    ),
 
-    'maxxvit_rmlp_nano_rw_256': _cfg(
+    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
+    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_small_rw_256': _cfg(
+    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_base_rw_224': _cfg(url=''),
-    'maxxvit_rmlp_large_rw_224': _cfg(url=''),
+
+    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
+    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
+
+    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
 
     # MaxViT models ported from official Tensorflow impl
     'maxvit_tiny_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_tiny_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_tiny_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_small_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_base_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_large_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
 
     'maxvit_base_tf_224.in21k': _cfg(
         url=''),
     'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in21k': _cfg(
         url=''),
     'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
+        hf_hub_id='timm/',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_224.in21k': _cfg(
         url=''),
     'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
 })
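Note: every weighted entry now carries a pretrained tag ('.sw_in1k', '.sw_in12k_ft_in1k', etc), and hf_hub_id='timm/' resolves to the model name on the Hugging Face Hub. A usage sketch (assumes a timm version with pretrained-tag support and network access):

    import timm

    # the tag selects the weight source; the architecture is the part before the '.'
    model = timm.create_model('coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k', pretrained=True)
    model.eval()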
@@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
 
 @register_model
-def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
+def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
+def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
 
 
 @register_model