Merge remote-tracking branch 'origin/master' into bits_and_tpu

commit 38594ef7fd (pull/804/merge^2)
Author: Ross Wightman
@@ -21,6 +21,14 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### Sept 7, 2022
* Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) home now exists; look for more here in the future
* Add BEiT-v2 weights for base and large 224x224 models from https://github.com/microsoft/unilm/tree/master/beit2
* Add more weights in `maxxvit` series, including a `pico` (7.5M params, 1.9 GMACs) and two `tiny` variants (usage sketch below):
  * `maxvit_rmlp_pico_rw_256` - 80.5 @ 256, 81.3 @ 320 (T)
  * `maxvit_tiny_rw_224` - 83.5 @ 224 (G)
  * `maxvit_rmlp_tiny_rw_256` - 84.2 @ 256, 84.8 @ 320 (T)
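
A quick sanity-check sketch for the new weights (assumes a `timm` build from this branch with the release weights above published):

```python
import timm
import torch

# Load one of the newly added weights; ~7.5M params / 1.9 GMACs per the notes above.
model = timm.create_model('maxvit_rmlp_pico_rw_256', pretrained=True).eval()
print(sum(p.numel() for p in model.parameters()) / 1e6, 'M params')

with torch.no_grad():
    logits = model(torch.randn(1, 3, 256, 256))  # native 256x256 input
print(logits.shape)  # torch.Size([1, 1000])
```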
### Aug 29, 2022
* MaxVit window size scales with img_size by default. Add new RelPosMlp MaxViT weight that leverages this (see the sketch below):
  * `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T)
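
Since window/grid sizes are now derived from `img_size`, the same checkpoint can be evaluated at a larger resolution, which is how the 320 numbers above were obtained. A sketch (passing `img_size` through `create_model` is an assumption, consistent with other timm model families):

```python
import timm
import torch

# One checkpoint, two eval resolutions; the window/grid partitioning
# is derived from img_size rather than fixed at build time.
for size in (256, 320):
    model = timm.create_model('maxvit_rmlp_nano_rw_256', pretrained=True, img_size=size).eval()
    with torch.no_grad():
        logits = model(torch.randn(1, 3, size, size))
    print(size, logits.shape)
```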
@@ -407,6 +415,8 @@ Model validation results can be found in the [documentation](https://rwightman.github.io/pytorch-image-models/)
My current [documentation](https://rwightman.github.io/pytorch-image-models/) for `timm` covers the basics.
Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) will be the documentation focus going forward and will eventually replace the `github.io` docs above.
[Getting Started with PyTorch Image Models (timm): A Practitioner's Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.

@@ -26,7 +26,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
-'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*',
]
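
The `'beit_*'` → `'beit*'` change widens the non-standard-model filter so the new `beitv2_*` names are matched as well. These filters use glob-style matching; a quick stdlib illustration:

```python
from fnmatch import fnmatch

for name in ('beit_base_patch16_224', 'beitv2_base_patch16_224'):
    print(name, fnmatch(name, 'beit_*'), fnmatch(name, 'beit*'))
# beit_base_patch16_224   True  True
# beitv2_base_patch16_224 False True   <- the old 'beit_*' pattern missed v2 models
```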

@@ -1,6 +1,25 @@
""" BEiT: BERT Pre-Training of Image Transformers (https://arxiv.org/abs/2106.08254)
Model from official source: https://github.com/microsoft/unilm/tree/master/beit
and
https://github.com/microsoft/unilm/tree/master/beit2
@inproceedings{beit,
title={{BEiT}: {BERT} Pre-Training of Image Transformers},
author={Hangbo Bao and Li Dong and Songhao Piao and Furu Wei},
booktitle={International Conference on Learning Representations},
year={2022},
url={https://openreview.net/forum?id=p-BhZSz59o4}
}
@article{beitv2,
title={{BEiT v2}: Masked Image Modeling with Vector-Quantized Visual Tokenizers},
author={Zhiliang Peng and Li Dong and Hangbo Bao and Qixiang Ye and Furu Wei},
year={2022},
eprint={2208.06366},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
At this point only the 1k fine-tuned classification weights and model configs have been added,
see original source above for pre-training models and procedure.
@@ -27,6 +46,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from .registry import register_model
@@ -69,6 +89,26 @@ default_cfgs = {
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841,
),
'beitv2_base_patch16_224': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
}
@@ -417,3 +457,39 @@ def beit_large_patch16_224_in22k(pretrained=False, **kwargs):
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beit_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_base_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_base_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_base_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_large_patch16_224(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def beitv2_large_patch16_224_in22k(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
use_abs_pos_emb=False, use_rel_pos_bias=True, init_values=1e-5, **kwargs)
model = _create_beit('beitv2_large_patch16_224_in22k', pretrained=pretrained, **model_kwargs)
return model
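
A hedged usage sketch for the new BEiT-v2 registrations (assumes the weight URLs above are reachable). Note the v2 entries explicitly pin `IMAGENET_DEFAULT_MEAN`/`STD`, overriding this file's `_cfg` default preprocessing:

```python
import timm
import torch
from timm.data import resolve_data_config, create_transform

model = timm.create_model('beitv2_base_patch16_224', pretrained=True).eval()

# Resolve the preprocessing attached to the model's default_cfg; for the
# beitv2 entries this reflects IMAGENET_DEFAULT_MEAN / IMAGENET_DEFAULT_STD.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)
print(config['mean'], config['std'])

with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])
```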

@@ -30,8 +30,8 @@ import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
-from .layers import trunc_normal_tf_, DropPath, to_2tuple, Mlp, get_attn, get_act_layer, get_norm_layer, \
-    ClassifierHead, LayerNorm2d, _assert
from .layers import DropPath, to_2tuple, to_ntuple, Mlp, ClassifierHead, LayerNorm2d, \
    get_attn, get_act_layer, get_norm_layer, _assert
from .registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias  # FIXME move to common location
@@ -321,7 +321,7 @@ class GlobalContextVitStage(nn.Module):
depth: int,
num_heads: int,
feat_size: Tuple[int, int],
-window_size: int,
window_size: Tuple[int, int],
downsample: bool = True,
global_norm: bool = False,
stage_norm: bool = False,
@@ -347,8 +347,9 @@ class GlobalContextVitStage(nn.Module):
else:
self.downsample = nn.Identity()
self.feat_size = feat_size
window_size = to_2tuple(window_size)
-feat_levels = int(math.log2(min(feat_size) / window_size))
feat_levels = int(math.log2(min(feat_size) / min(window_size)))
self.global_block = FeatureBlock(dim, feat_levels)
self.global_norm = norm_layer_cl(dim) if global_norm else nn.Identity()
@@ -400,7 +401,8 @@ class GlobalContextVit(nn.Module):
num_classes: int = 1000,
global_pool: str = 'avg',
img_size: Tuple[int, int] = 224,
-window_size: Tuple[int, ...] = (7, 7, 14, 7),
window_ratio: Tuple[int, ...] = (32, 32, 16, 32),
window_size: Tuple[int, ...] = None,
embed_dim: int = 64,
depths: Tuple[int, ...] = (3, 4, 19, 5),
num_heads: Tuple[int, ...] = (2, 4, 8, 16),
@@ -411,7 +413,7 @@ class GlobalContextVit(nn.Module):
proj_drop_rate: float = 0.,
attn_drop_rate: float = 0.,
drop_path_rate: float = 0.,
-weight_init='vit',
weight_init='',
act_layer: str = 'gelu',
norm_layer: str = 'layernorm2d',
norm_layer_cl: str = 'layernorm',
@@ -429,6 +431,11 @@ class GlobalContextVit(nn.Module):
self.drop_rate = drop_rate
num_stages = len(depths)
self.num_features = int(embed_dim * 2 ** (num_stages - 1))
if window_size is not None:
window_size = to_ntuple(num_stages)(window_size)
else:
assert window_ratio is not None
window_size = tuple([(img_size[0] // r, img_size[1] // r) for r in to_ntuple(num_stages)(window_ratio)])
self.stem = Stem(
in_chs=in_chans,
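
To make the new `window_ratio` default concrete: with `img_size=(224, 224)` and the default `(32, 32, 16, 32)`, the derivation above reproduces the old hard-coded `(7, 7, 14, 7)` windows, and scales them for other resolutions. A standalone sketch of the expression:

```python
# Mirrors the window_size derivation in GlobalContextVit.__init__ above.
def derive_window_sizes(img_size, window_ratio):
    return tuple((img_size[0] // r, img_size[1] // r) for r in window_ratio)

print(derive_window_sizes((224, 224), (32, 32, 16, 32)))
# ((7, 7), (7, 7), (14, 14), (7, 7))  -> matches the removed (7, 7, 14, 7) default
print(derive_window_sizes((256, 256), (32, 32, 16, 32)))
# ((8, 8), (8, 8), (16, 16), (8, 8))  -> windows scale with input resolution
```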
@@ -480,7 +487,7 @@ class GlobalContextVit(nn.Module):
nn.init.zeros_(module.bias)
else:
if isinstance(module, nn.Linear):
-trunc_normal_tf_(module.weight, std=.02)
nn.init.normal_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
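
For reference, the practical difference in that default branch: `trunc_normal_tf_` resamples values outside roughly ±2σ, while `nn.init.normal_` draws from an untruncated normal. A standalone comparison using plain PyTorch (its built-in `trunc_normal_` stands in here for timm's TF-style variant):

```python
import torch

w = torch.empty(10000)

# Untruncated normal (the new default-branch init above): tails are unbounded.
torch.nn.init.normal_(w, std=.02)
print(w.abs().max().item())  # typically > 0.04, i.e. beyond 2 sigma

# Truncated normal: all values forced into [-2 sigma, 2 sigma].
torch.nn.init.trunc_normal_(w, std=.02, a=-0.04, b=0.04)
print(w.abs().max().item())  # <= 0.04
```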
@@ -490,7 +497,6 @@ class GlobalContextVit(nn.Module):
k for k, _ in self.named_parameters()
if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])}
@torch.jit.ignore
def group_matcher(self, coarse=False):
matcher = dict(
@@ -567,7 +573,6 @@ def gcvit_small(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(3, 6, 12, 24),
-window_size=(7, 7, 14, 7),
embed_dim=96,
mlp_ratio=2,
layer_scale=1e-5,
@@ -580,7 +585,6 @@ def gcvit_base(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(3, 4, 19, 5),
num_heads=(4, 8, 16, 32),
-window_size=(7, 7, 14, 7),
embed_dim=128,
mlp_ratio=2,
layer_scale=1e-5,

@@ -82,6 +82,7 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
),
'coatnet_2_rw_224': _cfg(url=''),
'coatnet_3_rw_224': _cfg(url=''),
# Highly experimental configs
'coatnet_bn_0_rw_224': _cfg(
@@ -94,6 +95,8 @@ default_cfgs = {
'coatnet_rmlp_0_rw_224': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_rmlp_2_rw_224': _cfg(url=''),
'coatnet_rmlp_3_rw_224': _cfg(url=''),
'coatnet_nano_cc_224': _cfg(url=''),
'coatnext_nano_rw_224': _cfg(url=''),
@@ -110,13 +113,31 @@ default_cfgs = {
'maxvit_nano_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
-'maxvit_tiny_rw_224': _cfg(url=''),
-'maxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
'maxvit_tiny_rw_256': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_pico_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_tiny_rw_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_small_rw_224': _cfg(
url=''),
'maxvit_rmlp_small_rw_256': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_nano_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_small_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# Trying to be like the MaxViT paper configs
'maxvit_tiny_224': _cfg(url=''),
@@ -139,7 +160,7 @@ class MaxxVitTransformerCfg:
pool_type: str = 'avg2'
rel_pos_type: str = 'bias'
rel_pos_dim: int = 512  # for relative position types w/ MLP
-partition_stride: int = 32
partition_ratio: int = 32
window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
init_values: Optional[float] = None
@@ -173,7 +194,7 @@ class MaxxVitConvCfg:
attn_layer: str = 'se'
attn_act_layer: str = 'silu'
attn_ratio: float = 0.25
-init_values: Optional[float] = 1e-5  # for ConvNeXt block
init_values: Optional[float] = 1e-6  # for ConvNeXt block, ignored by MBConv
act_layer: str = 'gelu'
norm_layer: str = ''
norm_layer_cl: str = ''
@@ -209,10 +230,12 @@ def _rw_coat_cfg(
pool_type='avg2',
conv_output_bias=False,
conv_attn_early=False,
conv_attn_act_layer='relu',
conv_norm_layer='',
transformer_shortcut_bias=True,
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
init_values=None,
rel_pos_type='bias',
rel_pos_dim=512,
):
@@ -237,7 +260,7 @@ def _rw_coat_cfg(
expand_output=False,
output_bias=conv_output_bias,
attn_early=conv_attn_early,
-attn_act_layer='relu',
attn_act_layer=conv_attn_act_layer,
act_layer='silu',
norm_layer=conv_norm_layer,
),
@@ -245,6 +268,7 @@ def _rw_coat_cfg(
expand_first=False,
shortcut_bias=transformer_shortcut_bias,
pool_type=pool_type,
init_values=init_values,
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type,
@@ -263,6 +287,7 @@ def _rw_max_cfg(
transformer_norm_layer_cl='layernorm',
window_size=None,
dim_head=32,
init_values=None,
rel_pos_type='bias',
rel_pos_dim=512,
):
@@ -287,6 +312,7 @@ def _rw_max_cfg(
pool_type=pool_type,
dim_head=dim_head,
window_size=window_size,
init_values=init_values,
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type,
@@ -303,7 +329,8 @@ def _next_cfg(
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=None,
-rel_pos_type='bias',
init_values=1e-6,
rel_pos_type='mlp',  # MLP by default for maxxvit
rel_pos_dim=512,
):
# For experimental models with convnext instead of mbconv
@@ -313,6 +340,7 @@ def _next_cfg(
stride_mode=stride_mode,
pool_type=pool_type,
expand_output=False,
init_values=init_values,
norm_layer=conv_norm_layer,
norm_layer_cl=conv_norm_layer_cl,
),
@@ -320,6 +348,7 @@ def _next_cfg(
expand_first=False,
pool_type=pool_type,
window_size=window_size,
init_values=init_values,
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
rel_pos_type=rel_pos_type,
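
The `init_values` argument threaded through these `_rw_coat_cfg` / `_rw_max_cfg` / `_next_cfg` helpers sets the LayerScale init: each residual branch is multiplied by a learnable per-channel gamma starting at a small constant, so deep blocks begin near identity. A minimal sketch of the mechanism (not timm's exact module):

```python
import torch
import torch.nn as nn

class LayerScale(nn.Module):
    """Learnable per-channel scaling of a residual branch (gamma * x)."""
    def __init__(self, dim: int, init_values: float = 1e-6):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * self.gamma  # near-zero at init -> block starts close to identity

print(LayerScale(512)(torch.randn(2, 196, 512)).abs().mean().item())  # tiny at init
```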
@@ -372,7 +401,21 @@ model_cfgs = dict(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
-**_rw_coat_cfg(stride_mode='dw'),
**_rw_coat_cfg(
stride_mode='dw',
conv_attn_act_layer='silu',
init_values=1e-6,
),
),
coatnet_3_rw_224=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
**_rw_coat_cfg(
stride_mode='dw',
conv_attn_act_layer='silu',
init_values=1e-6,
),
),
# Highly experimental configs
@@ -419,6 +462,29 @@ model_cfgs = dict(
rel_pos_dim=384,  # was supposed to be 512, woops
),
),
coatnet_rmlp_2_rw_224=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
**_rw_coat_cfg(
stride_mode='dw',
conv_attn_act_layer='silu',
init_values=1e-6,
rel_pos_type='mlp'
),
),
coatnet_rmlp_3_rw_224=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
**_rw_coat_cfg(
stride_mode='dw',
conv_attn_act_layer='silu',
init_values=1e-6,
rel_pos_type='mlp'
),
),
coatnet_nano_cc_224=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
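
The `rel_pos_type='mlp'` in the `coatnet_rmlp_*` / `maxvit_rmlp_*` configs above selects the `RelPosMlp` bias (the class gcvit.py imports from `vision_transformer_relpos`): a small MLP maps relative window coordinates to per-head attention biases instead of indexing a fixed learned table, with `rel_pos_dim` as its hidden width. A heavily simplified sketch of the idea, not timm's actual implementation:

```python
import torch
import torch.nn as nn

def relative_coords(win_h: int, win_w: int) -> torch.Tensor:
    # (N, N, 2) table of pairwise (dy, dx) offsets within a window, N = win_h * win_w
    pos = torch.stack(torch.meshgrid(
        torch.arange(win_h), torch.arange(win_w), indexing='ij'), dim=-1).reshape(-1, 2)
    return (pos[:, None, :] - pos[None, :, :]).float()

class MlpRelPosBias(nn.Module):
    def __init__(self, window_size=(7, 7), num_heads=8, rel_pos_dim=512):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(2, rel_pos_dim), nn.ReLU(), nn.Linear(rel_pos_dim, num_heads))
        self.register_buffer('coords', relative_coords(*window_size))

    def forward(self) -> torch.Tensor:
        # (num_heads, N, N) bias added to the attention logits
        return self.mlp(self.coords).permute(2, 0, 1)

print(MlpRelPosBias()().shape)  # torch.Size([8, 49, 49])
```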
@@ -495,6 +561,14 @@ model_cfgs = dict(
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_rmlp_pico_rw_256=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
@@ -502,6 +576,34 @@ model_cfgs = dict(
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_small_rw_224=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_rmlp_small_rw_256=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
@@ -509,6 +611,7 @@ model_cfgs = dict(
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxxvit_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
@@ -517,6 +620,20 @@ model_cfgs = dict(
weight_init='normal',
**_next_cfg(),
),
maxxvit_tiny_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_next_cfg(),
),
maxxvit_small_rw_256=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
),
# Trying to be like the MaxViT paper configs
maxvit_tiny_224=MaxxVitCfg(
@@ -1458,7 +1575,7 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
if cfg.window_size is not None:
assert cfg.grid_size
return cfg
-partition_size = img_size[0] // cfg.partition_stride, img_size[1] // cfg.partition_stride
partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio
cfg = replace(cfg, window_size=partition_size, grid_size=partition_size)
return cfg
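
The rename from `partition_stride` to `partition_ratio` matches the semantics: window and grid size are `img_size // ratio`, so partitions scale with resolution, which is exactly the "window size scales with img_size" behavior called out in the README hunk above. The arithmetic, spelled out:

```python
# Mirrors cfg_window_size above for the default partition_ratio=32.
for img_size in ((224, 224), (256, 256), (320, 320)):
    partition_size = img_size[0] // 32, img_size[1] // 32
    print(img_size, '-> window/grid', partition_size)
# (224, 224) -> window/grid (7, 7)
# (256, 256) -> window/grid (8, 8)
# (320, 320) -> window/grid (10, 10)
```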
@@ -1618,6 +1735,11 @@ def coatnet_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_3_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_bn_0_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_bn_0_rw_224', pretrained=pretrained, **kwargs)
@@ -1638,6 +1760,16 @@ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_nano_cc_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_nano_cc_224', pretrained=pretrained, **kwargs)
@@ -1698,11 +1830,31 @@ def maxvit_tiny_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_pico_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_pico_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_tiny_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_tiny_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_small_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@@ -1713,6 +1865,16 @@ def maxxvit_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvit_tiny_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_tiny_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvit_small_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_small_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_224(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs)
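
With all of the above registered, the new variants are discoverable through the standard registry API; a quick sketch using names from the registrations in this diff:

```python
import timm

# Glob-filter the registry for the newly registered variants.
print(timm.list_models('maxvit_rmlp*'))
print(timm.list_models('coatnet_rmlp*'))
print(timm.list_models('maxxvit*'))

# Entries whose _cfg has url='' are registered but have no pretrained
# weights yet; they can still be built randomly initialized.
model = timm.create_model('maxxvit_tiny_rw_256', pretrained=False)
```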
