From 9914f744dc1eae8618b95e73335526458dc89149 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 10 Oct 2022 21:49:18 -0700
Subject: [PATCH] Add more maxxvit weights including ConvNeXt conv block based
 experiments.

---
 README.md              | 15 +++++++++++++
 timm/models/maxxvit.py | 49 ++++++++++++++++++++++++++++---------------
 2 files changed, 46 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index d0f6cd0e..a8863d34 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,21 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 
 ## What's New
 
+### Oct 10, 2022
+* More weights in `maxxvit` series, incl first ConvNeXt block based `coatnext` and `maxxvit` experiments:
+  * `coatnext_nano_rw_224` - 82.0 @ 224 (G) -- (uses ConvNeXt conv block, no BatchNorm)
+  * `maxxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.7 @ 320 (G) (uses ConvNeXt conv block, no BN)
+  * `maxvit_rmlp_small_rw_224` - 84.5 @ 224, 85.1 @ 320 (G)
+  * `maxxvit_rmlp_small_rw_256` - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt block, no BN)
+  * `coatnet_rmlp_2_rw_224` - 84.6 @ 224, 85 @ 320 (T)
+
+### Sept 23, 2022
+* LAION-2B CLIP image towers supported as pretrained backbones for fine-tune or features (no classifier)
+  * vit_base_patch32_224_clip_laion2b
+  * vit_large_patch14_224_clip_laion2b
+  * vit_huge_patch14_224_clip_laion2b
+  * vit_giant_patch14_224_clip_laion2b
+
 ### Sept 7, 2022
 * Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) home now exists, look for more here in the future
 * Add BEiT-v2 weights for base and large 224x224 models from https://github.com/microsoft/unilm/tree/master/beit2
diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 1090e755..bd529245 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -95,10 +95,13 @@ default_cfgs = {
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
-    'coatnet_rmlp_2_rw_224': _cfg(url=''),
+    'coatnet_rmlp_2_rw_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
     'coatnet_rmlp_3_rw_224': _cfg(url=''),
     'coatnet_nano_cc_224': _cfg(url=''),
-    'coatnext_nano_rw_224': _cfg(url=''),
+    'coatnext_nano_rw_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
+        crop_pct=0.9),
 
     # Trying to be like the CoAtNet paper configs
     'coatnet_0_224': _cfg(url=''),
@@ -128,16 +131,22 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_rmlp_small_rw_224': _cfg(
-        url=''),
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
+        crop_pct=0.9,
+    ),
     'maxvit_rmlp_small_rw_256': _cfg(
         url='',
         input_size=(3, 256, 256), pool_size=(8, 8)),
 
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
-    'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_small_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_rmlp_nano_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_rmlp_small_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
 
     # Trying to be like the MaxViT paper configs
     'maxvit_tiny_224': _cfg(url=''),
@@ -334,13 +343,14 @@ def _next_cfg(
         rel_pos_dim=512,
 ):
     # For experimental models with convnext instead of mbconv
+    init_values = to_2tuple(init_values)
     return dict(
         conv_cfg=MaxxVitConvCfg(
             block_type='convnext',
             stride_mode=stride_mode,
             pool_type=pool_type,
             expand_output=False,
-            init_values=init_values,
+            init_values=init_values[0],
             norm_layer=conv_norm_layer,
             norm_layer_cl=conv_norm_layer_cl,
         ),
@@ -348,7 +358,7 @@ def _next_cfg(
             expand_first=False,
             pool_type=pool_type,
             window_size=window_size,
-            init_values=init_values,
+            init_values=init_values[1],
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -497,7 +507,10 @@ model_cfgs = dict(
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         weight_init='normal',
-        **_next_cfg(),
+        **_next_cfg(
+            rel_pos_type='bias',
+            init_values=(1e-5, None)
+        ),
     ),
 
     # Trying to be like the CoAtNet paper configs
@@ -612,7 +625,7 @@ model_cfgs = dict(
         **_rw_max_cfg(),
     ),
 
-    maxxvit_nano_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
@@ -620,14 +633,14 @@ model_cfgs = dict(
         weight_init='normal',
         **_next_cfg(),
     ),
-    maxxvit_tiny_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_next_cfg(),
     ),
-    maxxvit_small_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_small_rw_256=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
@@ -1861,18 +1874,18 @@ def maxvit_tiny_pm_256(pretrained=False, **kwargs):
 
 
 @register_model
-def maxxvit_nano_rw_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs)
+def maxxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def maxxvit_tiny_rw_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
+def maxxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def maxxvit_small_rw_256(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_small_rw_256', pretrained=pretrained, **kwargs)
+def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
 
 
 @register_model
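
Usage sketch (assumes a timm build that contains this commit and the standard timm.create_model / PyTorch APIs): the renamed maxxvit_rmlp_*_rw_256 entry points and the newly released weights above can be exercised like this.

    import timm
    import torch

    # This patch renames maxxvit_{nano,tiny,small}_rw_256 to maxxvit_rmlp_*_rw_256,
    # so the new names must be used with create_model.
    model = timm.create_model('maxxvit_rmlp_nano_rw_256', pretrained=True)
    model.eval()

    # The config for this model declares input_size=(3, 256, 256).
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # (1, 1000) with the default ImageNet-1k classifier head

The same pattern applies to the other weights added here (e.g. coatnext_nano_rw_224, coatnet_rmlp_2_rw_224, maxvit_rmlp_small_rw_224), using their respective configured input sizes.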