|
|
@ -259,8 +259,6 @@ def _rw_max_cfg(
|
|
|
|
# - mbconv expansion calculated from input instead of output chs
|
|
|
|
# - mbconv expansion calculated from input instead of output chs
|
|
|
|
# - mbconv shortcut and final 1x1 conv did not have a bias
|
|
|
|
# - mbconv shortcut and final 1x1 conv did not have a bias
|
|
|
|
# - mbconv uses silu in timm, not gelu
|
|
|
|
# - mbconv uses silu in timm, not gelu
|
|
|
|
# - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat)
|
|
|
|
|
|
|
|
# - default to avg pool for mbconv downsample instead of 1x1 or dw conv
|
|
|
|
|
|
|
|
# - expansion in attention block done via output proj, not input proj
|
|
|
|
# - expansion in attention block done via output proj, not input proj
|
|
|
|
return dict(
|
|
|
|
return dict(
|
|
|
|
conv_cfg=MaxxVitConvCfg(
|
|
|
|
conv_cfg=MaxxVitConvCfg(
|
|
|
@ -411,18 +409,19 @@ model_cfgs = dict(
|
|
|
|
rel_pos_dim=384, # was supposed to be 512, woops
|
|
|
|
rel_pos_dim=384, # was supposed to be 512, woops
|
|
|
|
),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
coatnext_nano_rw_224=MaxxVitCfg(
|
|
|
|
coatnet_nano_cc_224=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
depths=(3, 4, 6, 3),
|
|
|
|
depths=(3, 4, 6, 3),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
**_next_cfg(),
|
|
|
|
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
|
|
|
|
|
|
|
|
**_rw_coat_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
coatnet_nano_cc_224=MaxxVitCfg(
|
|
|
|
coatnext_nano_rw_224=MaxxVitCfg(
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
embed_dim=(64, 128, 256, 512),
|
|
|
|
depths=(3, 4, 6, 3),
|
|
|
|
depths=(3, 4, 6, 3),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
|
|
|
|
weight_init='normal',
|
|
|
|
**_rw_coat_cfg(),
|
|
|
|
**_next_cfg(),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
|
|
# Trying to be like the CoAtNet paper configs
|
|
|
|
# Trying to be like the CoAtNet paper configs
|
|
|
@ -498,6 +497,7 @@ model_cfgs = dict(
|
|
|
|
depths=(1, 2, 3, 1),
|
|
|
|
depths=(1, 2, 3, 1),
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
block_type=('M',) * 4,
|
|
|
|
stem_width=(32, 64),
|
|
|
|
stem_width=(32, 64),
|
|
|
|
|
|
|
|
weight_init='normal',
|
|
|
|
**_next_cfg(window_size=8),
|
|
|
|
**_next_cfg(window_size=8),
|
|
|
|
),
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
|
|