From f489f02ad1235c1345a67c61a78e54a4214ac8e6 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 6 Sep 2022 15:10:52 -0700
Subject: [PATCH 1/7] Make gcvit window size ratio based to improve resolution
 changing support #1449. Change default init to original.

---
 timm/models/gcvit.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/timm/models/gcvit.py b/timm/models/gcvit.py
index bad40bd6..fb375e2c 100644
--- a/timm/models/gcvit.py
+++ b/timm/models/gcvit.py
@@ -30,8 +30,8 @@ import torch.utils.checkpoint as checkpoint
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .fx_features import register_notrace_function
 from .helpers import build_model_with_cfg, named_apply
-from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
-    ClassifierHead, LayerNorm2d, _assert
+from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d,\
+    get_attn, get_act_layer, get_norm_layer, _assert
 from .registry import register_model
 from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move to common location
@@ -321,7 +321,7 @@ class GlobalContextVitStage(nn.Module):
             depth: int,
             num_heads: int,
             feat_size: Tuple[int, int],
-            window_size: int,
+            window_size: Tuple[int, int],
             downsample: bool = True,
             global_norm: bool = False,
             stage_norm: bool = False,
@@ -347,8 +347,9 @@ class GlobalContextVitStage(nn.Module):
         else:
             self.downsample = nn.Identity()
         self.feat_size = feat_size
+        window_size = to_2tuple(window_size)
 
-        feat_levels = int(math.log2(min(feat_size) / window_size))
+        feat_levels = int(math.log2(min(feat_size) / min(window_size)))
         self.global_block = FeatureBlock(dim, feat_levels)
         self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
@@ -400,7 +401,8 @@ class GlobalContextVit(nn.Module):
             num_classes: int = 1000,
             global_pool: str = 'avg',
             img_size: Tuple[int, int] = 224,
-            window_size: Tuple[int, ...] = (7, 7, 14, 7),
+            window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
+            window_size: Tuple[int, ...] = None,
             embed_dim: int = 64,
             depths: Tuple[int, ...] = (3, 4, 19, 5),
             num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@@ -411,7 +413,7 @@ class GlobalContextVit(nn.Module):
             proj_drop_rate: float = 0.,
             attn_drop_rate: float = 0.,
             drop_path_rate: float = 0.,
-            weight_init='vit',
+            weight_init='',
             act_layer: str = 'gelu',
             norm_layer: str = 'layernorm2d',
             norm_layer_cl: str = 'layernorm',
@@ -429,6 +431,11 @@ class GlobalContextVit(nn.Module):
         self.drop_rate = drop_rate
         num_stages = len(depths)
         self.num_features = int(embed_dim * 2 ** (num_stages - 1))
+        if window_size is not None:
+            window_size = to_ntuple(num_stages)(window_size)
+        else:
+            assert window_ratio is not None
+            window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
 
         self.stem = Stem(
             in_chs=in_chans,
@@ -480,7 +487,7 @@ class GlobalContextVit(nn.Module):
                 nn.init.zeros_(module.bias)
         else:
             if isinstance(module, nn.Linear):
-                trunc_normal_tf_(module.weight, std=.02)
+                nn.init.normal_(module.weight, std=.02)
                 if module.bias is not None:
                     nn.init.zeros_(module.bias)
@@ -490,7 +497,6 @@ class GlobalContextVit(nn.Module):
             k for k, _ in self.named_parameters()
             if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
 
-
     @torch.jit.ignore
     def group_matcher(self, coarse=False):
         matcher = dict(
@@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(3, 6, 12, 24),
-        window_size=(7, 7, 14, 7),
         embed_dim=96,
         mlp_ratio=2,
         layer_scale=1e-5,
@@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
     model_kwargs = dict(
         depths=(3, 4, 19, 5),
         num_heads=(4, 8, 16, 32),
-        window_size=(7, 7, 14, 7),
         embed_dim=128,
         mlp_ratio=2,
         layer_scale=1e-5,
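A standalone sketch (not part of the patch) of the ratio-based sizing introduced above; the to_ntuple here is a simplified stand-in for timm.models.layers.to_ntuple, and the arithmetic mirrors the new GlobalContextVit constructor logic:

    def to_ntuple(n):
        # stand-in: broadcast a scalar to an n-tuple, pass tuples through
        def parse(x):
            return x if isinstance(x, tuple) else (x,) * n
        return parse

    img_size = (224, 224)
    num_stages = 4
    window_ratio = (32, 32, 16, 32)
    window_size = tuple((img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio))
    print(window_size)  # ((7, 7), (7, 7), (14, 14), (7, 7)) -- matches the old fixed (7, 7, 14, 7) at 224

    img_size = (288, 288)
    print(tuple((img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)))
    # ((9, 9), (9, 9), (18, 18), (9, 9)) -- window sizes now track the input resolution
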
From dc90816f2676c4e393fa26264eadc00a1f0c3e53 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 6 Sep 2022 16:14:41 -0700
Subject: [PATCH 2/7] Add `maxvit_tiny_rw_224` weights 83.5 @ 224 and
 `maxvit_rmlp_pico_rw_256` relpos weights, 80.5 @ 256, 81.3 @ 320

---
 timm/models/maxxvit.py | 26 ++++++++++++++++++++++----
 1 file changed, 22 insertions(+), 4 deletions(-)

diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index 495f682b..f1df148b 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -110,8 +110,14 @@ default_cfgs = {
     'maxvit_nano_rw_256': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_rw_224': _cfg(url=''),
-    'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_tiny_rw_224': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
+    'maxvit_tiny_rw_256': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_pico_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_rmlp_nano_rw_256': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -139,7 +145,7 @@ class MaxxVitTransformerCfg:
     pool_type: str = 'avg2'
     rel_pos_type: str = 'bias'
     rel_pos_dim: int = 512  # for relative position types w/ MLP
-    partition_stride: int = 32
+    partition_ratio: int = 32
     window_size: Optional[Tuple[int, int]] = None
     grid_size: Optional[Tuple[int, int]] = None
     init_values: Optional[float] = None
@@ -495,6 +501,13 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
+    maxvit_rmlp_pico_rw_256=MaxxVitCfg(
+        embed_dim=(32, 64, 128, 256),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(24, 32),
+        **_rw_max_cfg(rel_pos_type='mlp'),
+    ),
     maxvit_rmlp_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
@@ -1458,7 +1471,7 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
     if cfg.window_size is not None:
         assert cfg.grid_size
         return cfg
-    partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
+    partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
     cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
     return cfg
 
@@ -1698,6 +1711,11 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
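The partition_stride to partition_ratio rename above reflects that the value divides the image size rather than striding over it. A minimal sketch of the cfg_window_size arithmetic from the diff:

    img_size = (256, 256)
    partition_ratio = 32
    partition_size = img_size[0] // partition_ratio, img_size[1] // partition_ratio
    print(partition_size)  # (8, 8); at 224x224 it is (7, 7), at 320x320 it is (10, 10)

Both window_size and grid_size default to this partition_size when not given explicitly, which is what lets the 256-trained weights above be evaluated at 320.
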
From 914544fc81d2fe83efdbbeaa747035b2f1a4f991 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 6 Sep 2022 20:25:18 -0700
Subject: [PATCH 3/7] Add beitv2 224x224 checkpoints from
 https://github.com/microsoft/unilm/tree/master/beit2

---
 timm/models/beit.py | 76 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 76 insertions(+)

diff --git a/timm/models/beit.py b/timm/models/beit.py
index a2083a4a..60497d9a 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -1,6 +1,25 @@
 """ BEIT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
 
 Model from official source: https://github.com/microsoft/unilm/tree/master/beit
+and
+https://github.com/microsoft/unilm/tree/master/beit2
+
+@inproceedings{beit,
+title={{BEiT}: {BERT} Pre-Training of Image Transformers},
+author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
+booktitle={International Conference on Learning Representations},
+year={2022},
+url={https://openreview.net/forum?id=p-BhZSz59o4}
+}
+
+@article{beitv2,
+title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
+author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
+year={2022},
+eprint={2208.06366},
+archivePrefix={arXiv},
+primaryClass={cs.CV}
+}
 
 At this point only the 1k fine-tuned classification weights and model configs have been added,
 see original source above for pre-training models and procedure.
@@ -27,6 +46,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint
 
+from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
 from .registry import register_model
@@ -69,6 +89,26 @@ default_cfgs = {
         url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
         num_classes=21841,
     ),
+
+    'beitv2_base_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_base_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
+        crop_pct=0.95,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
+    'beitv2_large_patch16_224_in22k': _cfg(
+        url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
+        num_classes=21841,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
+    ),
 }
 
 
@@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
         use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
     model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
     return model
+
+
+@register_model
+def beitv2_base_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
+        use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
+    model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
+    return model
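A usage sketch for the new registrations, assuming a timm checkout that includes this patch and a reachable checkpoint URL:

    import timm
    import torch

    model = timm.create_model('beitv2_base_patch16_224', pretrained=True).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, 224, 224))
    print(logits.shape)  # torch.Size([1, 1000])

The _in22k variants return 21841-class logits, per their num_classes settings above.
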
From da6f8f5a4069a2c2e26648a8a18eada1d510f1ac Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 7 Sep 2022 08:09:47 -0700
Subject: [PATCH 4/7] Fix beitv2 tests

---
 tests/test_models.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_models.py b/tests/test_models.py
index 175137e2..d007d65a 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -26,7 +26,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
 # transformer models don't support many of the spatial / feature based model functionalities
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
+    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
     'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*',
 ]
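The one-character filter change matters because these patterns are fnmatch-style globs ('beit_*' cannot match the new 'beitv2_*' model names). A quick check:

    from fnmatch import fnmatch

    print(fnmatch('beitv2_base_patch16_224', 'beit_*'))  # False -- the old pattern misses beitv2
    print(fnmatch('beitv2_base_patch16_224', 'beit*'))   # True
    print(fnmatch('beit_base_patch16_224', 'beit*'))     # True -- the original models are still covered
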
From c1b3cea19df84beb9d5e141272b383e9ba9a0980 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 7 Sep 2022 10:27:11 -0700
Subject: [PATCH 5/7] Add maxvit_rmlp_tiny_rw_256 model def and weights w/
 84.2 top-1 @ 256, 84.8 @ 320

---
 timm/models/maxxvit.py | 15 +++++++++++++++
 1 file changed, 15 insertions(+)

diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index f1df148b..f10e9f59 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -121,6 +121,9 @@ default_cfgs = {
     'maxvit_rmlp_nano_rw_256': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_tiny_rw_256': _cfg(
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-2da819a5.pth',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@@ -515,6 +518,13 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
+    maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(rel_pos_type='mlp'),
+    ),
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
@@ -1721,6 +1731,11 @@ def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)

From de40f66536daeeda6f013737789467733f0e902f Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 7 Sep 2022 10:40:58 -0700
Subject: [PATCH 6/7] Update README.md

---
 README.md | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/README.md b/README.md
index b5945b19..d0f6cd0e 100644
--- a/README.md
+++ b/README.md
@@ -21,6 +21,14 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 
 ## What's New
 
+### Sept 7, 2022
+* Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) home now exists, look for more here in the future
+* Add BEiT-v2 weights for base and large 224x224 models from https://github.com/microsoft/unilm/tree/master/beit2
+* Add more weights in `maxxvit` series incl a `pico` (7.5M params, 1.9 GMACs), two `tiny` variants:
+  * `maxvit_rmlp_pico_rw_256` - 80.5 @ 256, 81.3 @ 320 (T)
+  * `maxvit_tiny_rw_224` - 83.5 @ 224 (G)
+  * `maxvit_rmlp_tiny_rw_256` - 84.2 @ 256, 84.8 @ 320 (T)
+
 ### Aug 29, 2022
 * MaxVit window size scales with img_size by default. Add new RelPosMlp MaxViT weight that leverages this:
   * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T)
@@ -407,6 +415,8 @@ Model validation results can be found in the [documentation](https://rwightman.g
 
 My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
 
+Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
+
 [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
 
 [timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
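The "@ 320" numbers in the README entry above rely on window/grid sizes being derived from img_size (the partition_ratio logic shown in patch 2). A sketch of evaluating a 256-trained checkpoint at 320, assuming the img_size override reaches the MaxxVit constructor through create_model:

    import timm
    import torch

    model = timm.create_model('maxvit_rmlp_tiny_rw_256', pretrained=True, img_size=320).eval()
    with torch.no_grad():
        print(model(torch.randn(1, 3, 320, 320)).shape)  # torch.Size([1, 1000])
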
From fa8c84eede55b36861460cc8ee6ac201c068df4d Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 7 Sep 2022 12:37:37 -0700
Subject: [PATCH 7/7] Update maxvit_tiny_256 weight to better iter, add
 coatnet / maxvit / maxxvit model defs for future runs

---
 timm/models/maxxvit.py | 139 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 134 insertions(+), 5 deletions(-)

diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py
index f10e9f59..1090e755 100644
--- a/timm/models/maxxvit.py
+++ b/timm/models/maxxvit.py
@@ -82,6 +82,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
     'coatnet_2_rw_224': _cfg(url=''),
+    'coatnet_3_rw_224': _cfg(url=''),
 
     # Highly experimental configs
     'coatnet_bn_0_rw_224': _cfg(
@@ -94,6 +95,8 @@ default_cfgs = {
     'coatnet_rmlp_0_rw_224': _cfg(url=''),
     'coatnet_rmlp_1_rw_224': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
+    'coatnet_rmlp_2_rw_224': _cfg(url=''),
+    'coatnet_rmlp_3_rw_224': _cfg(url=''),
     'coatnet_nano_cc_224': _cfg(url=''),
     'coatnext_nano_rw_224': _cfg(url=''),
 
@@ -122,10 +125,19 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
     'maxvit_rmlp_tiny_rw_256': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-2da819a5.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_rmlp_small_rw_224': _cfg(
+        url=''),
+    'maxvit_rmlp_small_rw_256': _cfg(
+        url='',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+
     'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+
     'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_small_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
 
     # Trying to be like the MaxViT paper configs
     'maxvit_tiny_224': _cfg(url=''),
@@ -182,7 +194,7 @@ class MaxxVitConvCfg:
     attn_layer: str = 'se'
     attn_act_layer: str = 'silu'
     attn_ratio: float = 0.25
-    init_values: Optional[float] = 1e-5  # for ConvNeXt block
+    init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
     act_layer: str = 'gelu'
     norm_layer: str = ''
     norm_layer_cl: str = ''
@@ -218,10 +230,12 @@ def _rw_coat_cfg(
         pool_type='avg2',
         conv_output_bias=False,
         conv_attn_early=False,
+        conv_attn_act_layer='relu',
         conv_norm_layer='',
         transformer_shortcut_bias=True,
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
+        init_values=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -246,7 +260,7 @@ def _rw_coat_cfg(
             expand_output=False,
             output_bias=conv_output_bias,
             attn_early=conv_attn_early,
-            attn_act_layer='relu',
+            attn_act_layer=conv_attn_act_layer,
             act_layer='silu',
             norm_layer=conv_norm_layer,
         ),
@@ -254,6 +268,7 @@ def _rw_coat_cfg(
             expand_first=False,
             shortcut_bias=transformer_shortcut_bias,
             pool_type=pool_type,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -272,6 +287,7 @@ def _rw_max_cfg(
         transformer_norm_layer_cl='layernorm',
         window_size=None,
         dim_head=32,
+        init_values=None,
         rel_pos_type='bias',
         rel_pos_dim=512,
 ):
@@ -296,6 +312,7 @@ def _rw_max_cfg(
             pool_type=pool_type,
             dim_head=dim_head,
             window_size=window_size,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -312,7 +329,8 @@ def _next_cfg(
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
         window_size=None,
-        rel_pos_type='bias',
+        init_values=1e-6,
+        rel_pos_type='mlp',  # MLP by default for maxxvit
         rel_pos_dim=512,
 ):
     # For experimental models with convnext instead of mbconv
@@ -322,6 +340,7 @@ def _next_cfg(
             stride_mode=stride_mode,
             pool_type=pool_type,
             expand_output=False,
+            init_values=init_values,
             norm_layer=conv_norm_layer,
             norm_layer_cl=conv_norm_layer_cl,
         ),
@@ -329,6 +348,7 @@ def _next_cfg(
             expand_first=False,
             pool_type=pool_type,
             window_size=window_size,
+            init_values=init_values,
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
             rel_pos_type=rel_pos_type,
@@ -381,7 +401,21 @@ model_cfgs = dict(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
-        **_rw_coat_cfg(stride_mode='dw'),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+        ),
+    ),
+    coatnet_3_rw_224=MaxxVitCfg(
+        embed_dim=(192, 384, 768, 1536),
+        depths=(2, 6, 14, 2),
+        stem_width=(96, 192),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+        ),
     ),
 
     # Highly experimental configs
@@ -428,6 +462,29 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
+    coatnet_rmlp_2_rw_224=MaxxVitCfg(
+        embed_dim=(128, 256, 512, 1024),
+        depths=(2, 6, 14, 2),
+        stem_width=(64, 128),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+            rel_pos_type='mlp'
+        ),
+    ),
+    coatnet_rmlp_3_rw_224=MaxxVitCfg(
+        embed_dim=(192, 384, 768, 1536),
+        depths=(2, 6, 14, 2),
+        stem_width=(96, 192),
+        **_rw_coat_cfg(
+            stride_mode='dw',
+            conv_attn_act_layer='silu',
+            init_values=1e-6,
+            rel_pos_type='mlp'
+        ),
+    ),
+
     coatnet_nano_cc_224=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
@@ -504,6 +561,7 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
+
     maxvit_rmlp_pico_rw_256=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
@@ -525,6 +583,27 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
+    maxvit_rmlp_small_rw_224=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(
+            rel_pos_type='mlp',
+            init_values=1e-6,
+        ),
+    ),
+    maxvit_rmlp_small_rw_256=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_rw_max_cfg(
+            rel_pos_type='mlp',
+            init_values=1e-6,
+        ),
+    ),
+
     maxvit_tiny_pm_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
@@ -532,6 +611,7 @@ model_cfgs = dict(
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
+
     maxxvit_nano_rw_256=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
@@ -540,6 +620,20 @@ model_cfgs = dict(
         weight_init='normal',
         **_next_cfg(),
     ),
+    maxxvit_tiny_rw_256=MaxxVitCfg(
+        embed_dim=(64, 128, 256, 512),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(32, 64),
+        **_next_cfg(),
+    ),
+    maxxvit_small_rw_256=MaxxVitCfg(
+        embed_dim=(96, 192, 384, 768),
+        depths=(2, 2, 5, 2),
+        block_type=('M',) * 4,
+        stem_width=(48, 96),
+        **_next_cfg(),
+    ),
 
     # Trying to be like the MaxViT paper configs
     maxvit_tiny_224=MaxxVitCfg(
@@ -1641,6 +1735,11 @@ def coatnet_2_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def coatnet_3_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def coatnet_bn_0_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)
@@ -1661,6 +1760,16 @@ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def coatnet_nano_cc_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)
@@ -1736,6 +1845,16 @@ def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@@ -1746,6 +1865,16 @@ def maxxvit_nano_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs)
 
 
+@register_model
+def maxxvit_tiny_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def maxxvit_small_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvit_small_rw_256', pretrained=pretrained, **kwargs)
+
+
 @register_model
 def maxvit_tiny_224(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs)
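None of the new defs in this last patch ship weights yet (url=''), but they are registered and buildable. A smoke-test sketch over a few of the added names, with per-model input resolutions matching their configs:

    import timm
    import torch

    for name, res in [
            ('coatnet_rmlp_2_rw_224', 224),
            ('maxvit_rmlp_small_rw_224', 224),
            ('maxxvit_small_rw_256', 256),
    ]:
        model = timm.create_model(name, pretrained=False).eval()
        with torch.no_grad():
            out = model(torch.randn(1, 3, res, res))
        print(name, tuple(out.shape))  # (1, 1000) for each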