Update maxvit_tiny_256 weight to better iter, add coatnet / maxvit / maxxvit model defs for future runs

pull/804/merge
Ross Wightman 2 years ago
parent de40f66536
commit fa8c84eede
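
For context, the new entries in this diff are registered via `@register_model`, so once merged they should be creatable through timm's model factory like any other entry. A minimal usage sketch, assuming this branch of `timm` is installed; note most of the new defs have empty `url=''` configs, so only the updated `maxvit_rmlp_tiny_rw_256` weight is expected to load with `pretrained=True`:

```python
import torch
import timm

# New defs without released weights must be created with pretrained=False.
model = timm.create_model('coatnet_3_rw_224', pretrained=False)

# maxvit_rmlp_tiny_rw_256 points at the updated (better iter) checkpoint,
# so pretrained=True should pull the new sw-bbef0ff5 weights.
tiny = timm.create_model('maxvit_rmlp_tiny_rw_256', pretrained=True)

with torch.no_grad():
    out = tiny(torch.randn(1, 3, 256, 256))  # input_size=(3, 256, 256) per its default cfg
print(out.shape)  # (1, 1000) for the default ImageNet-1k head
```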

@@ -82,6 +82,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),
+    'coatnet_3_rw_224': _cfg(url=''),
 
     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
@@ -94,6 +95,8 @@ default_cfgs = {
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
+    'coatnet_rmlp_2_rw_224': _cfg(url=''),
+    'coatnet_rmlp_3_rw_224': _cfg(url=''),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),
@@ -122,10 +125,19 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_rmlp_tiny_rw_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-2da819a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_small_rw_224': _cfg(
+        url=''),
+    'maxvit_rmlp_small_rw_256': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_small_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     # Trying to be like the MaxViT paper configs
     'maxvit_tiny_224': _cfg(url=''),
@@ -182,7 +194,7 @@ class MaxxVitConvCfg:
     attn_layer: str = 'se'
     attn_act_layer: str = 'silu'
     attn_ratio: float = 0.25
-    init_values: Optional[float] = 1e-5  # for ConvNeXt block
+    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
     act_layer: str = 'gelu'
     norm_layer: str = ''
     norm_layer_cl: str = ''
@@ -218,10 +230,12 @@ def _rw_coat_cfg(
         pool_type='avg2',
         conv_output_bias=False,
         conv_attn_early=False,
+        conv_attn_act_layer='relu',
         conv_norm_layer='',
         transformer_shortcut_bias=True,
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
+        init_values=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -246,7 +260,7 @@ def _rw_coat_cfg(
             expand_output=False,
             output_bias=conv_output_bias,
             attn_early=conv_attn_early,
-            attn_act_layer='relu',
+            attn_act_layer=conv_attn_act_layer,
             act_layer='silu',
             norm_layer=conv_norm_layer,
         ),
@@ -254,6 +268,7 @@ def _rw_coat_cfg(
             expand_first=False,
             shortcut_bias=transformer_shortcut_bias,
             pool_type=pool_type,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -272,6 +287,7 @@ def _rw_max_cfg(
         transformer_norm_layer_cl='layernorm',
         window_size=None,
         dim_head=32,
+        init_values=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -296,6 +312,7 @@ def _rw_max_cfg(
             pool_type=pool_type,
             dim_head=dim_head,
             window_size=window_size,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -312,7 +329,8 @@ def _next_cfg(
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
         window_size=None,
-        rel_pos_type='bias',
+        init_values=1e-6,
+        rel_pos_type='mlp',  # MLP by default for maxxvit
         rel_pos_dim=512,
 ):
     # For experimental models with convnext instead of mbconv
@@ -322,6 +340,7 @@ def _next_cfg(
             stride_mode=stride_mode,
             pool_type=pool_type,
             expand_output=False,
+            init_values=init_values,
             norm_layer=conv_norm_layer,
             norm_layer_cl=conv_norm_layer_cl,
         ),
@@ -329,6 +348,7 @@ def _next_cfg(
             expand_first=False,
             pool_type=pool_type,
             window_size=window_size,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -381,7 +401,21 @@ model_cfgs = dict(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
-        **_rw_coat_cfg(stride_mode='dw'),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+        ),
+    ),
+    coatnet_3_rw_224=MaxxVitCfg(
+        embed_dim=(192, 384, 768, 1536),
+        depths=(2, 6, 14, 2),
+        stem_width=(96, 192),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+        ),
     ),
 
     # Highly experimental configs
@@ -428,6 +462,29 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
+    coatnet_rmlp_2_rw_224=MaxxVitCfg(
+        embed_dim=(128, 256, 512, 1024),
+        depths=(2, 6, 14, 2),
+        stem_width=(64, 128),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+            rel_pos_type='mlp'
+        ),
+    ),
+    coatnet_rmlp_3_rw_224=MaxxVitCfg(
+        embed_dim=(192, 384, 768, 1536),
+        depths=(2, 6, 14, 2),
+        stem_width=(96, 192),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+            rel_pos_type='mlp'
+        ),
+    ),
     coatnet_nano_cc_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
@@ -504,6 +561,7 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
+
     maxvit_rmlp_pico_rw_256=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
@@ -525,6 +583,27 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
+    maxvit_rmlp_small_rw_224=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(
+            rel_pos_type='mlp',
+            init_values=1e-6,
+        ),
+    ),
+    maxvit_rmlp_small_rw_256=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(
+            rel_pos_type='mlp',
+            init_values=1e-6,
+        ),
+    ),
+
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
@@ -532,6 +611,7 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
+
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
@@ -540,6 +620,20 @@ model_cfgs = dict(
         weight_init='normal',
         **_next_cfg(),
     ),
+    maxxvit_tiny_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_next_cfg(),
+    ),
+    maxxvit_small_rw_256=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(48, 96),
+        **_next_cfg(),
+    ),
 
     # Trying to be like the MaxViT paper configs
     maxvit_tiny_224=MaxxVitCfg(
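
Between the config entries above and the `@register_model` wrappers below, the pattern is: a small helper (`_rw_coat_cfg`, `_rw_max_cfg`, `_next_cfg`) bundles the shared conv/transformer sub-configs into a dict that is splatted into each per-model `MaxxVitCfg`. A stripped-down sketch of that idea, using simplified stand-in dataclasses rather than the real timm classes (names here are illustrative, not the actual fields):

```python
from dataclasses import dataclass, field
from typing import Optional, Tuple

# Simplified stand-ins for the nested MaxxVit config dataclasses.
@dataclass
class ConvCfg:
    stride_mode: str = 'pool'
    attn_act_layer: str = 'relu'

@dataclass
class TransformerCfg:
    init_values: Optional[float] = None
    rel_pos_type: str = 'bias'

@dataclass
class ModelCfg:
    embed_dim: Tuple[int, ...]
    depths: Tuple[int, ...]
    stem_width: Tuple[int, int]
    conv_cfg: ConvCfg = field(default_factory=ConvCfg)
    transformer_cfg: TransformerCfg = field(default_factory=TransformerCfg)

def _coat_cfg(stride_mode='pool', conv_attn_act_layer='relu', init_values=None, rel_pos_type='bias'):
    # Helper returns the nested cfgs as a dict, ready to be splatted into ModelCfg.
    return dict(
        conv_cfg=ConvCfg(stride_mode=stride_mode, attn_act_layer=conv_attn_act_layer),
        transformer_cfg=TransformerCfg(init_values=init_values, rel_pos_type=rel_pos_type),
    )

# Mirrors the shape of the coatnet_rmlp_2_rw_224 entry above: shared defaults live in
# the helper, per-model overrides are passed as keyword arguments.
cfg = ModelCfg(
    embed_dim=(128, 256, 512, 1024),
    depths=(2, 6, 14, 2),
    stem_width=(64, 128),
    **_coat_cfg(stride_mode='dw', conv_attn_act_layer='silu', init_values=1e-6, rel_pos_type='mlp'),
)
print(cfg.transformer_cfg)
```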
@@ -1641,6 +1735,11 @@ def coatnet_2_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def coatnet_3_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def coatnet_bn_0_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)
@@ -1661,6 +1760,16 @@ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def coatnet_nano_cc_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)
@@ -1736,6 +1845,16 @@ def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@@ -1746,6 +1865,16 @@ def maxxvit_nano_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxxvit_tiny_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxxvit_small_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_small_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_224(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs)
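
A side note on the `init_values` plumbing added throughout this commit: in blocks that support it, `init_values` typically seeds a per-channel LayerScale multiplier on the residual branch (as in ConvNeXt), which the dataclass comment notes is ignored by MBConv. A rough sketch of that general pattern, assuming a generic residual block; this is illustrative only, not the exact MaxxVit implementation:

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Per-channel residual scaling, initialized to a small value (e.g. 1e-6)."""
    def __init__(self, dim: int, init_values: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma

# Typical use inside a residual branch: out = x + drop_path(ls(block(x)))
ls = LayerScale(dim=512, init_values=1e-6)
x = torch.randn(2, 196, 512)
print(ls(x).shape)  # torch.Size([2, 196, 512])
```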
