diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 15000b40..e64bd0ef 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -21,111 +21,13 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \ create_conv2d, get_act_layer, make_divisible, to_ntuple +from ._pretrained import generate_defaults from .registry import register_model __all__ = ['ConvNeXt'] # model_registry will add each entrypoint fn to this -def _cfg(url='', **kwargs): - return { - 'url': url, - 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.875, 'interpolation': 'bicubic', - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'stem.0', 'classifier': 'head.fc', - **kwargs - } - - -default_cfgs = dict( - # timm specific variants - convnext_atto=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_atto_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_femto=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_femto_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_pico=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', - test_input_size=(3, 288, 288), test_crop_pct=0.95), - convnext_pico_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_nano=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_nano_ols=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_tiny_hnf=_cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', - crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_small=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_base=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_large=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", - test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_small_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_base_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_large_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - convnext_xlarge_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', - test_input_size=(3, 288, 288), test_crop_pct=1.0), - - convnext_tiny_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_small_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_base_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_large_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - convnext_xlarge_384_in22ft1k=_cfg( - url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', - input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0), - - convnext_tiny_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), - convnext_small_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), - convnext_base_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), - convnext_large_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), - convnext_xlarge_in22k=_cfg( - url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), -) - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -459,6 +361,107 @@ def _create_convnext(variant, pretrained=False, **kwargs): return model + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = generate_defaults({ + # timm specific variants + 'convnext_atto.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_atto_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_femto_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth', + test_input_size=(3, 288, 288), test_crop_pct=0.95), + 'convnext_pico_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_nano_ols.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_tiny_hnf.timm_in1k': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in1k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_xlarge.untrained': _cfg(), + + 'convnext_tiny.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_base.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_large.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_xlarge.fb_in22k_ft_in1k': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth', + test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_small..fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_base.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_large.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg( + url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + 'convnext_tiny_in22k.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841), + 'convnext_small_in22k.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841), + 'convnext_base_in22k.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841), + 'convnext_large_in22k.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841), + 'convnext_xlarge_in22k.fb_in22k': _cfg( + url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841), +}) + + @register_model def convnext_atto(pretrained=False, **kwargs): # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M @@ -569,105 +572,7 @@ def convnext_large(pretrained=False, **kwargs): @register_model -def convnext_tiny_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_small_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_base_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - model = _create_convnext('convnext_base_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_large_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - model = _create_convnext('convnext_large_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - model = _create_convnext('convnext_xlarge_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_small_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_base_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - model = _create_convnext('convnext_base_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_large_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - model = _create_convnext('convnext_large_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - model = _create_convnext('convnext_xlarge_384_in22ft1k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_tiny_in22k(pretrained=False, **kwargs): - model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) - model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_small_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) - model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_base_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) - model = _create_convnext('convnext_base_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_large_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) - model = _create_convnext('convnext_large_in22k', pretrained=pretrained, **model_args) - return model - - -@register_model -def convnext_xlarge_in22k(pretrained=False, **kwargs): +def convnext_xlarge(pretrained=False, **kwargs): model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) - model = _create_convnext('convnext_xlarge_in22k', pretrained=pretrained, **model_args) + model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args) return model diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 51c683c0..3c0efc96 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -366,11 +366,11 @@ default_cfgs = { 'tf_efficientnetv2_m': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m-cc09e0cd.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l-d664b728.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_s_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21ft1k-d7dafa41.pth', @@ -379,15 +379,15 @@ default_cfgs = { 'tf_efficientnetv2_m_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21ft1k-bf41664a.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21ft1k-60127a9d.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_xl_in21ft1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21ft1k-06c35c48.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_s_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_s_21k-6337ad01.pth', @@ -396,15 +396,15 @@ default_cfgs = { 'tf_efficientnetv2_m_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_m_21k-361418a2.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_l_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_l_21k-91a19ec9.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_xl_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_xl_in21k-fd7e8abf.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), num_classes=21843, - input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0), + input_size=(3, 384, 384), test_input_size=(3, 512, 512), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'tf_efficientnetv2_b0': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-effv2-weights/tf_efficientnetv2_b0-c7cc451f.pth', diff --git a/timm/models/layers/activations.py b/timm/models/layers/activations.py index e16b3bd3..2f5476c0 100644 --- a/timm/models/layers/activations.py +++ b/timm/models/layers/activations.py @@ -143,3 +143,17 @@ class GELU(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: return F.gelu(input) + + +def gelu_tanh(x: torch.Tensor, inplace: bool = False) -> torch.Tensor: + return F.gelu(x, approximate='tanh') + + +class GELUTanh(nn.Module): + """Applies the Gaussian Error Linear Units function (w/ dummy inplace arg) + """ + def __init__(self, inplace: bool = False): + super(GELUTanh, self).__init__() + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.gelu(input, approximate='tanh') diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index a3044a3d..0b02398d 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -28,6 +28,7 @@ _ACT_FN_DEFAULT = dict( celu=F.celu, selu=F.selu, gelu=gelu, + gelu_tanh=gelu_tanh, sigmoid=sigmoid, tanh=tanh, hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid, @@ -71,6 +72,7 @@ _ACT_LAYER_DEFAULT = dict( celu=nn.CELU, selu=nn.SELU, gelu=GELU, + gelu_tanh=GELUTanh, sigmoid=Sigmoid, tanh=Tanh, hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid, diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index bd529245..13fd7abf 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -52,114 +52,19 @@ from .helpers import build_model_with_cfg, checkpoint_seq, named_apply from .fx_features import register_notrace_function from .layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm2d, LayerNorm from .layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d +from .layers import SelectAdaptivePool2d, create_pool2d from .layers import to_2tuple, extend_tuple, make_divisible, _assert +from ._pretrained import generate_defaults from .registry import register_model from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location __all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit'] -def _cfg(url='', **kwargs): - return { - 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), - 'crop_pct': 0.95, 'interpolation': 'bicubic', - 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), - 'first_conv': 'stem.conv1', 'classifier': 'head.fc', - 'fixed_input_size': True, - **kwargs - } - - -default_cfgs = { - # Fiddling with configs / defaults / still pretraining - 'coatnet_pico_rw_224': _cfg(url=''), - 'coatnet_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', - crop_pct=0.9), - 'coatnet_0_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), - 'coatnet_1_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' - ), - 'coatnet_2_rw_224': _cfg(url=''), - 'coatnet_3_rw_224': _cfg(url=''), - - # Highly experimental configs - 'coatnet_bn_0_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', - mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, - crop_pct=0.95), - 'coatnet_rmlp_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', - crop_pct=0.9), - 'coatnet_rmlp_0_rw_224': _cfg(url=''), - 'coatnet_rmlp_1_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), - 'coatnet_rmlp_2_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), - 'coatnet_rmlp_3_rw_224': _cfg(url=''), - 'coatnet_nano_cc_224': _cfg(url=''), - 'coatnext_nano_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', - crop_pct=0.9), - - # Trying to be like the CoAtNet paper configs - 'coatnet_0_224': _cfg(url=''), - 'coatnet_1_224': _cfg(url=''), - 'coatnet_2_224': _cfg(url=''), - 'coatnet_3_224': _cfg(url=''), - 'coatnet_4_224': _cfg(url=''), - 'coatnet_5_224': _cfg(url=''), - - # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), - 'maxvit_tiny_rw_256': _cfg( - url='', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_pico_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_tiny_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_small_rw_224': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', - crop_pct=0.9, - ), - 'maxvit_rmlp_small_rw_256': _cfg( - url='', - input_size=(3, 256, 256), pool_size=(8, 8)), - - 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - - 'maxxvit_rmlp_nano_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_small_rw_256': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', - input_size=(3, 256, 256), pool_size=(8, 8)), - - # Trying to be like the MaxViT paper configs - 'maxvit_tiny_224': _cfg(url=''), - 'maxvit_small_224': _cfg(url=''), - 'maxvit_base_224': _cfg(url=''), - 'maxvit_large_224': _cfg(url=''), - 'maxvit_xlarge_224': _cfg(url=''), -} - - @dataclass class MaxxVitTransformerCfg: dim_head: int = 32 + head_first: bool = True # head ordering in qkv channel dim expand_ratio: float = 4.0 expand_first: bool = True shortcut_bias: bool = True @@ -199,6 +104,7 @@ class MaxxVitConvCfg: stride_mode: str = 'dw' # stride done via one of 'pool', '1x1', 'dw' pool_type: str = 'avg2' downsample_pool_type: str = 'avg2' + padding: str = '' attn_early: bool = False # apply attn between conv2 and norm2, instead of after norm2 attn_layer: str = 'se' attn_act_layer: str = 'silu' @@ -228,499 +134,56 @@ class MaxxVitCfg: depths: Tuple[int, ...] = (2, 3, 5, 2) block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T') stem_width: Union[int, Tuple[int, int]] = 64 - stem_bias: bool = True + stem_bias: bool = False conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg() transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg() + head_hidden_size: int = None weight_init: str = 'vit_eff' -def _rw_coat_cfg( - stride_mode='pool', - pool_type='avg2', - conv_output_bias=False, - conv_attn_early=False, - conv_attn_act_layer='relu', - conv_norm_layer='', - transformer_shortcut_bias=True, - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - init_values=None, - rel_pos_type='bias', - rel_pos_dim=512, -): - # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit - # Common differences for initial timm models: - # - pre-norm layer in MZBConv included an activation after norm - # - mbconv expansion calculated from input instead of output chs - # - mbconv shortcut and final 1x1 conv did not have a bias - # - SE act layer was relu, not silu - # - mbconv uses silu in timm, not gelu - # - expansion in attention block done via output proj, not input proj - # Variable differences (evolved over training initial models): - # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) - # - SE attention was between conv2 and norm/act - # - default to avg pool for mbconv downsample instead of 1x1 or dw conv - # - transformer block shortcut has no bias - return dict( - conv_cfg=MaxxVitConvCfg( - stride_mode=stride_mode, - pool_type=pool_type, - pre_norm_act=True, - expand_output=False, - output_bias=conv_output_bias, - attn_early=conv_attn_early, - attn_act_layer=conv_attn_act_layer, - act_layer='silu', - norm_layer=conv_norm_layer, - ), - transformer_cfg=MaxxVitTransformerCfg( - expand_first=False, - shortcut_bias=transformer_shortcut_bias, - pool_type=pool_type, - init_values=init_values, - norm_layer=transformer_norm_layer, - norm_layer_cl=transformer_norm_layer_cl, - rel_pos_type=rel_pos_type, - rel_pos_dim=rel_pos_dim, - ), - ) - +class Attention2d(nn.Module): + """ multi-head attention for 2D NCHW tensors""" + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + dim_head: int = 32, + bias: bool = True, + expand_first: bool = True, + head_first: bool = True, + rel_pos_cls: Callable = None, + attn_drop: float = 0., + proj_drop: float = 0. + ): + super().__init__() + dim_out = dim_out or dim + dim_attn = dim_out if expand_first else dim + self.num_heads = dim_attn // dim_head + self.dim_head = dim_head + self.head_first = head_first + self.scale = dim_head ** -0.5 -def _rw_max_cfg( - stride_mode='dw', - pool_type='avg2', - conv_output_bias=False, - conv_attn_ratio=1 / 16, - conv_norm_layer='', - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - window_size=None, - dim_head=32, - init_values=None, - rel_pos_type='bias', - rel_pos_dim=512, -): - # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit - # Differences of initial timm models: - # - mbconv expansion calculated from input instead of output chs - # - mbconv shortcut and final 1x1 conv did not have a bias - # - mbconv uses silu in timm, not gelu - # - expansion in attention block done via output proj, not input proj - return dict( - conv_cfg=MaxxVitConvCfg( - stride_mode=stride_mode, - pool_type=pool_type, - expand_output=False, - output_bias=conv_output_bias, - attn_ratio=conv_attn_ratio, - act_layer='silu', - norm_layer=conv_norm_layer, - ), - transformer_cfg=MaxxVitTransformerCfg( - expand_first=False, - pool_type=pool_type, - dim_head=dim_head, - window_size=window_size, - init_values=init_values, - norm_layer=transformer_norm_layer, - norm_layer_cl=transformer_norm_layer_cl, - rel_pos_type=rel_pos_type, - rel_pos_dim=rel_pos_dim, - ), - ) + self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) + self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) + self.proj_drop = nn.Dropout(proj_drop) + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, C, H, W = x.shape -def _next_cfg( - stride_mode='dw', - pool_type='avg2', - conv_norm_layer='layernorm2d', - conv_norm_layer_cl='layernorm', - transformer_norm_layer='layernorm2d', - transformer_norm_layer_cl='layernorm', - window_size=None, - init_values=1e-6, - rel_pos_type='mlp', # MLP by default for maxxvit - rel_pos_dim=512, -): - # For experimental models with convnext instead of mbconv - init_values = to_2tuple(init_values) - return dict( - conv_cfg=MaxxVitConvCfg( - block_type='convnext', - stride_mode=stride_mode, - pool_type=pool_type, - expand_output=False, - init_values=init_values[0], - norm_layer=conv_norm_layer, - norm_layer_cl=conv_norm_layer_cl, - ), - transformer_cfg=MaxxVitTransformerCfg( - expand_first=False, - pool_type=pool_type, - window_size=window_size, - init_values=init_values[1], - norm_layer=transformer_norm_layer, - norm_layer_cl=transformer_norm_layer_cl, - rel_pos_type=rel_pos_type, - rel_pos_dim=rel_pos_dim, - ), - ) + if self.head_first: + q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) + else: + q, k, v = self.qkv(x).reshape(B, 3, self.num_heads, self.dim_head, -1).unbind(1) - -model_cfgs = dict( - # Fiddling with configs / defaults / still pretraining - coatnet_pico_rw_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 3, 5, 2), - stem_width=(32, 64), - **_rw_max_cfg( # using newer max defaults here - conv_output_bias=True, - conv_attn_ratio=0.25, - ), - ), - coatnet_nano_rw_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(3, 4, 6, 3), - stem_width=(32, 64), - **_rw_max_cfg( # using newer max defaults here - stride_mode='pool', - conv_output_bias=True, - conv_attn_ratio=0.25, - ), - ), - coatnet_0_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 3, 7, 2), # deeper than paper '0' model - stem_width=(32, 64), - **_rw_coat_cfg( - conv_attn_early=True, - transformer_shortcut_bias=False, - ), - ), - coatnet_1_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - stem_width=(32, 64), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_early=True, - transformer_shortcut_bias=False, - ) - ), - coatnet_2_rw_224=MaxxVitCfg( - embed_dim=(128, 256, 512, 1024), - depths=(2, 6, 14, 2), - stem_width=(64, 128), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_act_layer='silu', - init_values=1e-6, - ), - ), - coatnet_3_rw_224=MaxxVitCfg( - embed_dim=(192, 384, 768, 1536), - depths=(2, 6, 14, 2), - stem_width=(96, 192), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_act_layer='silu', - init_values=1e-6, - ), - ), - - # Highly experimental configs - coatnet_bn_0_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 3, 7, 2), # deeper than paper '0' model - stem_width=(32, 64), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_early=True, - transformer_shortcut_bias=False, - transformer_norm_layer='batchnorm2d', - ) - ), - coatnet_rmlp_nano_rw_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(3, 4, 6, 3), - stem_width=(32, 64), - **_rw_max_cfg( - conv_output_bias=True, - conv_attn_ratio=0.25, - rel_pos_type='mlp', - rel_pos_dim=384, - ), - ), - coatnet_rmlp_0_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 3, 7, 2), # deeper than paper '0' model - stem_width=(32, 64), - **_rw_coat_cfg( - stride_mode='dw', - rel_pos_type='mlp', - ), - ), - coatnet_rmlp_1_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - stem_width=(32, 64), - **_rw_coat_cfg( - pool_type='max', - conv_attn_early=True, - transformer_shortcut_bias=False, - rel_pos_type='mlp', - rel_pos_dim=384, # was supposed to be 512, woops - ), - ), - coatnet_rmlp_2_rw_224=MaxxVitCfg( - embed_dim=(128, 256, 512, 1024), - depths=(2, 6, 14, 2), - stem_width=(64, 128), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_act_layer='silu', - init_values=1e-6, - rel_pos_type='mlp' - ), - ), - coatnet_rmlp_3_rw_224=MaxxVitCfg( - embed_dim=(192, 384, 768, 1536), - depths=(2, 6, 14, 2), - stem_width=(96, 192), - **_rw_coat_cfg( - stride_mode='dw', - conv_attn_act_layer='silu', - init_values=1e-6, - rel_pos_type='mlp' - ), - ), - - coatnet_nano_cc_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(3, 4, 6, 3), - stem_width=(32, 64), - block_type=('C', 'C', ('C', 'T'), ('C', 'T')), - **_rw_coat_cfg(), - ), - coatnext_nano_rw_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(3, 4, 6, 3), - stem_width=(32, 64), - weight_init='normal', - **_next_cfg( - rel_pos_type='bias', - init_values=(1e-5, None) - ), - ), - - # Trying to be like the CoAtNet paper configs - coatnet_0_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 3, 5, 2), - stem_width=64, - ), - coatnet_1_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - stem_width=64, - ), - coatnet_2_224=MaxxVitCfg( - embed_dim=(128, 256, 512, 1024), - depths=(2, 6, 14, 2), - stem_width=128, - ), - coatnet_3_224=MaxxVitCfg( - embed_dim=(192, 384, 768, 1536), - depths=(2, 6, 14, 2), - stem_width=192, - ), - coatnet_4_224=MaxxVitCfg( - embed_dim=(192, 384, 768, 1536), - depths=(2, 12, 28, 2), - stem_width=192, - ), - coatnet_5_224=MaxxVitCfg( - embed_dim=(256, 512, 1280, 2048), - depths=(2, 12, 28, 2), - stem_width=192, - ), - - # Experimental MaxVit configs - maxvit_pico_rw_256=MaxxVitCfg( - embed_dim=(32, 64, 128, 256), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(24, 32), - **_rw_max_cfg(), - ), - maxvit_nano_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(1, 2, 3, 1), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - maxvit_tiny_rw_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - maxvit_tiny_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - - maxvit_rmlp_pico_rw_256=MaxxVitCfg( - embed_dim=(32, 64, 128, 256), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(24, 32), - **_rw_max_cfg(rel_pos_type='mlp'), - ), - maxvit_rmlp_nano_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(1, 2, 3, 1), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(rel_pos_type='mlp'), - ), - maxvit_rmlp_tiny_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(rel_pos_type='mlp'), - ), - maxvit_rmlp_small_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg( - rel_pos_type='mlp', - init_values=1e-6, - ), - ), - maxvit_rmlp_small_rw_256=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg( - rel_pos_type='mlp', - init_values=1e-6, - ), - ), - - maxvit_tiny_pm_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('PM',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - - maxxvit_rmlp_nano_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(1, 2, 3, 1), - block_type=('M',) * 4, - stem_width=(32, 64), - weight_init='normal', - **_next_cfg(), - ), - maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_next_cfg(), - ), - maxxvit_rmlp_small_rw_256=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(48, 96), - **_next_cfg(), - ), - - # Trying to be like the MaxViT paper configs - maxvit_tiny_224=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=64, - ), - maxvit_small_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=64, - ), - maxvit_base_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - block_type=('M',) * 4, - stem_width=64, - ), - maxvit_large_224=MaxxVitCfg( - embed_dim=(128, 256, 512, 1024), - depths=(2, 6, 14, 2), - block_type=('M',) * 4, - stem_width=128, - ), - maxvit_xlarge_224=MaxxVitCfg( - embed_dim=(192, 384, 768, 1536), - depths=(2, 6, 14, 2), - block_type=('M',) * 4, - stem_width=192, - ), - -) - - -class Attention2d(nn.Module): - """ multi-head attention for 2D NCHW tensors""" - def __init__( - self, - dim: int, - dim_out: Optional[int] = None, - dim_head: int = 32, - bias: bool = True, - expand_first: bool = True, - rel_pos_cls: Callable = None, - attn_drop: float = 0., - proj_drop: float = 0. - ): - super().__init__() - dim_out = dim_out or dim - dim_attn = dim_out if expand_first else dim - self.num_heads = dim_attn // dim_head - self.dim_head = dim_head - self.scale = dim_head ** -0.5 - - self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias) - self.rel_pos = rel_pos_cls(num_heads=self.num_heads) if rel_pos_cls else None - self.attn_drop = nn.Dropout(attn_drop) - self.proj = nn.Conv2d(dim_attn, dim_out, 1, bias=bias) - self.proj_drop = nn.Dropout(proj_drop) - - def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): - B, C, H, W = x.shape - - q, k, v = self.qkv(x).view(B, self.num_heads, self.dim_head * 3, -1).chunk(3, dim=2) - - attn = (q.transpose(-2, -1) @ k) * self.scale - if self.rel_pos is not None: - attn = self.rel_pos(attn) - elif shared_rel_pos is not None: - attn = attn + shared_rel_pos - attn = attn.softmax(dim=-1) - attn = self.attn_drop(attn) + attn = (q.transpose(-2, -1) @ k) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) x = self.proj(x) @@ -737,6 +200,7 @@ class AttentionCl(nn.Module): dim_head: int = 32, bias: bool = True, expand_first: bool = True, + head_first: bool = True, rel_pos_cls: Callable = None, attn_drop: float = 0., proj_drop: float = 0. @@ -747,6 +211,7 @@ class AttentionCl(nn.Module): assert dim_attn % dim_head == 0, 'attn dim should be divisible by head_dim' self.num_heads = dim_attn // dim_head self.dim_head = dim_head + self.head_first = head_first self.scale = dim_head ** -0.5 self.qkv = nn.Linear(dim, dim_attn * 3, bias=bias) @@ -759,7 +224,10 @@ class AttentionCl(nn.Module): B = x.shape[0] restore_shape = x.shape[:-1] - q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + if self.head_first: + q, k, v = self.qkv(x).view(B, -1, self.num_heads, self.dim_head * 3).transpose(1, 2).chunk(3, dim=3) + else: + q, k, v = self.qkv(x).reshape(B, -1, 3, self.num_heads, self.dim_head).transpose(1, 3).unbind(2) attn = (q @ k.transpose(-2, -1)) * self.scale if self.rel_pos is not None: @@ -810,18 +278,20 @@ class Downsample2d(nn.Module): dim: int, dim_out: int, pool_type: str = 'avg2', + padding: str = '', bias: bool = True, ): super().__init__() assert pool_type in ('max', 'max2', 'avg', 'avg2') if pool_type == 'max': - self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.pool = create_pool2d('max', kernel_size=3, stride=2, padding=padding or 1) elif pool_type == 'max2': - self.pool = nn.MaxPool2d(2) # kernel_size == stride == 2 + self.pool = create_pool2d('max', 2, padding=padding or 0) # kernel_size == stride == 2 elif pool_type == 'avg': - self.pool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1, count_include_pad=False) + self.pool = create_pool2d( + 'avg', kernel_size=3, stride=2, count_include_pad=False, padding=padding or 1) else: - self.pool = nn.AvgPool2d(2) # kernel_size == stride == 2 + self.pool = create_pool2d('avg', 2, padding=padding or 0) if dim != dim_out: self.expand = nn.Conv2d(dim, dim_out, 1, bias=bias) @@ -973,7 +443,8 @@ class MbConvBlock(nn.Module): groups = num_groups(cfg.group_size, mid_chs) if stride == 2: - self.shortcut = Downsample2d(in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias) + self.shortcut = Downsample2d( + in_chs, out_chs, pool_type=cfg.pool_type, bias=cfg.output_bias, padding=cfg.padding) else: self.shortcut = nn.Identity() @@ -991,14 +462,15 @@ class MbConvBlock(nn.Module): self.pre_norm = norm_act_layer(in_chs, apply_act=cfg.pre_norm_act) if stride_pool > 1: - self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type) + self.down = Downsample2d(in_chs, in_chs, pool_type=cfg.downsample_pool_type, padding=cfg.padding) else: self.down = nn.Identity() self.conv1_1x1 = create_conv2d(in_chs, mid_chs, 1, stride=stride_1) self.norm1 = norm_act_layer(mid_chs) self.conv2_kxk = create_conv2d( - mid_chs, mid_chs, cfg.kernel_size, stride=stride_2, dilation=dilation_2, groups=groups) + mid_chs, mid_chs, cfg.kernel_size, + stride=stride_2, dilation=dilation_2, groups=groups, padding=cfg.padding) attn_kwargs = {} if isinstance(cfg.attn_layer, str): @@ -1164,6 +636,8 @@ def get_rel_pos_cls(cfg: MaxxVitTransformerCfg, window_size): rel_pos_cls = partial(RelPosMlp, window_size=window_size, hidden_dim=cfg.rel_pos_dim) elif cfg.rel_pos_type == 'bias': rel_pos_cls = partial(RelPosBias, window_size=window_size) + elif cfg.rel_pos_type == 'bias_tf': + rel_pos_cls = partial(RelPosBiasTf, window_size=window_size) return rel_pos_cls @@ -1193,6 +667,7 @@ class PartitionAttentionCl(nn.Module): dim, dim_head=cfg.dim_head, bias=cfg.attn_bias, + head_first=cfg.head_first, rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, @@ -1256,6 +731,7 @@ class ParallelPartitionAttention(nn.Module): dim // 2, dim_head=cfg.dim_head, bias=cfg.attn_bias, + head_first=cfg.head_first, rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, @@ -1265,6 +741,7 @@ class ParallelPartitionAttention(nn.Module): dim // 2, dim_head=cfg.dim_head, bias=cfg.attn_bias, + head_first=cfg.head_first, rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, @@ -1364,6 +841,7 @@ class PartitionAttention2d(nn.Module): dim, dim_head=cfg.dim_head, bias=cfg.attn_bias, + head_first=cfg.head_first, rel_pos_cls=rel_pos_cls, attn_drop=cfg.attn_drop, proj_drop=cfg.proj_drop, @@ -1413,6 +891,7 @@ class MaxxVitBlock(nn.Module): conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU + use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention drop_path: float = 0., ): super().__init__() @@ -1423,11 +902,12 @@ class MaxxVitBlock(nn.Module): attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl self.nchw_attn = use_nchw_attn - self.attn_block = partition_layer(**attn_kwargs) + self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) def init_weights(self, scheme=''): - named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) + if self.attn_block is not None: + named_apply(partial(_init_transformer, scheme=scheme), self.attn_block) named_apply(partial(_init_transformer, scheme=scheme), self.attn_grid) named_apply(partial(_init_conv, scheme=scheme), self.conv) @@ -1437,7 +917,8 @@ class MaxxVitBlock(nn.Module): if not self.nchw_attn: x = x.permute(0, 2, 3, 1) # to NHWC (channels-last) - x = self.attn_block(x) + if self.attn_block is not None: + x = self.attn_block(x) x = self.attn_grid(x) if not self.nchw_attn: x = x.permute(0, 3, 1, 2) # back to NCHW @@ -1544,183 +1025,1004 @@ class MaxxVitStage(nn.Module): self.blocks = nn.Sequential(*blocks) def forward(self, x): - if self.grad_checkpointing and not torch.jit.is_scripting(): - x = checkpoint_seq(self.blocks, x) - else: - x = self.blocks(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class Stem(nn.Module): + + def __init__( + self, + in_chs: int, + out_chs: int, + kernel_size: int = 3, + padding: str = '', + bias: bool = False, + act_layer: str = 'gelu', + norm_layer: str = 'batchnorm2d', + norm_eps: float = 1e-5, + ): + super().__init__() + if not isinstance(out_chs, (list, tuple)): + out_chs = to_2tuple(out_chs) + + norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) + self.out_chs = out_chs[-1] + self.stride = 2 + + self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2, padding=padding, bias=bias) + self.norm1 = norm_act_layer(out_chs[0]) + self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1, padding=padding, bias=bias) + + def init_weights(self, scheme=''): + named_apply(partial(_init_conv, scheme=scheme), self) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.conv2(x) + return x + + +def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): + if cfg.window_size is not None: + assert cfg.grid_size + return cfg + partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio + cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) + return cfg + + +def generate_lookup_tensor( + length: int, + max_relative_position: Optional[int] = None, +): + """Generate a one_hot lookup tensor to reindex embeddings along one dimension. + Args: + length: the length to reindex to. + max_relative_position: the maximum relative position to consider. + Relative position embeddings for distances above this threshold + are zeroed out. + Returns: + a lookup Tensor of size [length, length, vocab_size] that satisfies + ret[n,m,v] = 1{m - n + max_relative_position = v}. + """ + if max_relative_position is None: + max_relative_position = length - 1 + # Return the cached lookup tensor, otherwise compute it and cache it. + vocab_size = 2 * max_relative_position + 1 + ret = torch.zeros(length, length, vocab_size) + for i in range(length): + for x in range(length): + v = x - i + max_relative_position + if abs(x - i) > max_relative_position: + continue + ret[i, x, v] = 1 + return ret + + +def reindex_2d_einsum_lookup( + relative_position_tensor, + height: int, + width: int, + height_lookup: torch.Tensor, + width_lookup: torch.Tensor, +) -> torch.Tensor: + """Reindex 2d relative position bias with 2 independent einsum lookups. + Args: + relative_position_tensor: tensor of shape + [..., vocab_height, vocab_width, ...]. + height: height to reindex to. + width: width to reindex to. + height_lookup: one-hot height lookup + width_lookup: one-hot width lookup + Returns: + reindexed_tensor: a Tensor of shape + [..., height * width, height * width, ...] + """ + reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup) + reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup) + area = height * width + return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area) + + +class RelPosBiasTf(nn.Module): + + def __init__(self, window_size, num_heads, prefix_tokens=0): + super().__init__() + assert prefix_tokens <= 1 + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.num_heads = num_heads + + vocab_height = 2 * window_size[0] - 1 + vocab_width = 2 * window_size[1] - 1 + self.bias_shape = (self.num_heads, vocab_height, vocab_width) + self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape)) + self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False) + self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False) + self.init_weights() + + def init_weights(self): + nn.init.normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + # FIXME change to not use one-hot/einsum? + return reindex_2d_einsum_lookup( + self.relative_position_bias_table, + self.window_size[0], + self.window_size[1], + self.height_lookup, + self.width_lookup + ) + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class NormMlpHead(nn.Module): + + def __init__( + self, + in_features, + num_classes, + hidden_size=None, + pool_type='avg', + drop_rate=0., + norm_layer=nn.LayerNorm, + act_layer=nn.Tanh, + ): + super().__init__() + self.drop_rate = drop_rate + self.num_features = in_features + + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.norm = norm_layer(in_features) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() + if hidden_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(in_features, hidden_size)), + ('act', act_layer()), + ])) + self.num_features = hidden_size + else: + self.pre_logits = nn.Identity() + self.drop = nn.Dropout(self.drop_rate) + self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x, pre_logits: bool = False): + x = self.global_pool(x) + x = self.norm(x) + x = self.flatten(x) + x = self.pre_logits(x) + if pre_logits: + return x + x = self.fc(x) + return x + + +class MaxxVit(nn.Module): + """ CoaTNet + MaxVit base model. + + Highly configurable for different block compositions, tensor layouts, pooling types. + """ + + def __init__( + self, + cfg: MaxxVitCfg, + img_size: Union[int, Tuple[int, int]] = 224, + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + drop_rate: float = 0., + drop_path_rate: float = 0. + ): + super().__init__() + img_size = to_2tuple(img_size) + transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = cfg.embed_dim[-1] + self.drop_rate = drop_rate + self.grad_checkpointing = False + + self.stem = Stem( + in_chs=in_chans, + out_chs=cfg.stem_width, + padding=cfg.conv_cfg.padding, + bias=cfg.stem_bias, + act_layer=cfg.conv_cfg.act_layer, + norm_layer=cfg.conv_cfg.norm_layer, + norm_eps=cfg.conv_cfg.norm_eps, + ) + + stride = self.stem.stride + feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + + num_stages = len(cfg.embed_dim) + assert len(cfg.depths) == num_stages + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] + in_chs = self.stem.out_chs + stages = [] + for i in range(num_stages): + stage_stride = 2 + out_chs = cfg.embed_dim[i] + feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) + stages += [MaxxVitStage( + in_chs, + out_chs, + depth=cfg.depths[i], + block_types=cfg.block_type[i], + conv_cfg=cfg.conv_cfg, + transformer_cfg=transformer_cfg, + feat_size=feat_size, + drop_path=dpr[i], + )] + stride *= stage_stride + in_chs = out_chs + self.stages = nn.Sequential(*stages) + + final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) + if cfg.head_hidden_size: + self.norm = nn.Identity() + self.head = NormMlpHead( + self.num_features, + num_classes, + hidden_size=cfg.head_hidden_size, + pool_type=global_pool, + drop_rate=drop_rate, + norm_layer=final_norm_layer, + ) + else: + # standard classifier head w/ norm, pooling, fc classifier + self.norm = final_norm_layer(self.num_features) + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) + + # Weight init (default PyTorch init works well for AdamW if scheme not set) + assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') + if cfg.weight_init: + named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) + + def _init_weights(self, module, name, scheme=''): + if hasattr(module, 'init_weights'): + try: + module.init_weights(scheme=scheme) + except TypeError: + module.init_weights() + + @torch.jit.ignore + def no_weight_decay(self): + return { + k for k, _ in self.named_parameters() + if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + matcher = dict( + stem=r'^stem', # stem and embed + blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] + ) + return matcher + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is None: + global_pool = self.head.global_pool.pool_type + self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=pre_logits) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) return x -class Stem(nn.Module): +def _rw_coat_cfg( + stride_mode='pool', + pool_type='avg2', + conv_output_bias=False, + conv_attn_early=False, + conv_attn_act_layer='relu', + conv_norm_layer='', + transformer_shortcut_bias=True, + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + init_values=None, + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Common differences for initial timm models: + # - pre-norm layer in MZBConv included an activation after norm + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - SE act layer was relu, not silu + # - mbconv uses silu in timm, not gelu + # - expansion in attention block done via output proj, not input proj + # Variable differences (evolved over training initial models): + # - avg pool with kernel_size=2 favoured downsampling (instead of maxpool for coat) + # - SE attention was between conv2 and norm/act + # - default to avg pool for mbconv downsample instead of 1x1 or dw conv + # - transformer block shortcut has no bias + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + pre_norm_act=True, + expand_output=False, + output_bias=conv_output_bias, + attn_early=conv_attn_early, + attn_act_layer=conv_attn_act_layer, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + shortcut_bias=transformer_shortcut_bias, + pool_type=pool_type, + init_values=init_values, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) - def __init__( - self, - in_chs: int, - out_chs: int, - kernel_size: int = 3, - act_layer: str = 'gelu', - norm_layer: str = 'batchnorm2d', - norm_eps: float = 1e-5, - ): - super().__init__() - if not isinstance(out_chs, (list, tuple)): - out_chs = to_2tuple(out_chs) - norm_act_layer = partial(get_norm_act_layer(norm_layer, act_layer), eps=norm_eps) - self.out_chs = out_chs[-1] - self.stride = 2 +def _rw_max_cfg( + stride_mode='dw', + pool_type='avg2', + conv_output_bias=False, + conv_attn_ratio=1 / 16, + conv_norm_layer='', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=None, + dim_head=32, + init_values=None, + rel_pos_type='bias', + rel_pos_dim=512, +): + # 'RW' timm variant models were created and trained before seeing https://github.com/google-research/maxvit + # Differences of initial timm models: + # - mbconv expansion calculated from input instead of output chs + # - mbconv shortcut and final 1x1 conv did not have a bias + # - mbconv uses silu in timm, not gelu + # - expansion in attention block done via output proj, not input proj + return dict( + conv_cfg=MaxxVitConvCfg( + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + output_bias=conv_output_bias, + attn_ratio=conv_attn_ratio, + act_layer='silu', + norm_layer=conv_norm_layer, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + dim_head=dim_head, + window_size=window_size, + init_values=init_values, + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) - self.conv1 = create_conv2d(in_chs, out_chs[0], kernel_size, stride=2) - self.norm1 = norm_act_layer(out_chs[0]) - self.conv2 = create_conv2d(out_chs[0], out_chs[1], kernel_size, stride=1) - def init_weights(self, scheme=''): - named_apply(partial(_init_conv, scheme=scheme), self) +def _next_cfg( + stride_mode='dw', + pool_type='avg2', + conv_norm_layer='layernorm2d', + conv_norm_layer_cl='layernorm', + transformer_norm_layer='layernorm2d', + transformer_norm_layer_cl='layernorm', + window_size=None, + init_values=1e-6, + rel_pos_type='mlp', # MLP by default for maxxvit + rel_pos_dim=512, +): + # For experimental models with convnext instead of mbconv + init_values = to_2tuple(init_values) + return dict( + conv_cfg=MaxxVitConvCfg( + block_type='convnext', + stride_mode=stride_mode, + pool_type=pool_type, + expand_output=False, + init_values=init_values[0], + norm_layer=conv_norm_layer, + norm_layer_cl=conv_norm_layer_cl, + ), + transformer_cfg=MaxxVitTransformerCfg( + expand_first=False, + pool_type=pool_type, + window_size=window_size, + init_values=init_values[1], + norm_layer=transformer_norm_layer, + norm_layer_cl=transformer_norm_layer_cl, + rel_pos_type=rel_pos_type, + rel_pos_dim=rel_pos_dim, + ), + ) - def forward(self, x): - x = self.conv1(x) - x = self.norm1(x) - x = self.conv2(x) - return x + +def _tf_cfg(): + return dict( + conv_cfg=MaxxVitConvCfg( + norm_eps=1e-3, + act_layer='gelu_tanh', + padding='same', + ), + transformer_cfg=MaxxVitTransformerCfg( + norm_eps=1e-5, + act_layer='gelu_tanh', + head_first=False, # heads are interleaved (q_nh, q_hdim, k_nh, q_hdim, ....) + rel_pos_type='bias_tf', + ), + ) -def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): - if cfg.window_size is not None: - assert cfg.grid_size - return cfg - partition_size = img_size[0] // cfg.partition_ratio, img_size[1] // cfg.partition_ratio - cfg = replace(cfg, window_size=partition_size, grid_size=partition_size) - return cfg +model_cfgs = dict( + # Fiddling with configs / defaults / still pretraining + coatnet_pico_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 3, 5, 2), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( # using newer max defaults here + stride_mode='pool', + conv_output_bias=True, + conv_attn_ratio=0.25, + ), + ), + coatnet_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + conv_attn_early=True, + transformer_shortcut_bias=False, + ), + ), + coatnet_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + ) + ), + coatnet_2_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + #init_values=1e-6, + ), + ), + coatnet_3_rw_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=(96, 192), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + ), + ), + # Highly experimental configs + coatnet_bn_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_early=True, + transformer_shortcut_bias=False, + transformer_norm_layer='batchnorm2d', + ) + ), + coatnet_rmlp_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + **_rw_max_cfg( + conv_output_bias=True, + conv_attn_ratio=0.25, + rel_pos_type='mlp', + rel_pos_dim=384, + ), + ), + coatnet_rmlp_0_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 7, 2), # deeper than paper '0' model + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + ), + ), + coatnet_rmlp_1_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + pool_type='max', + conv_attn_early=True, + transformer_shortcut_bias=False, + rel_pos_type='mlp', + rel_pos_dim=384, # was supposed to be 512, woops + ), + ), + coatnet_rmlp_1_rw2_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=(32, 64), + **_rw_coat_cfg( + stride_mode='dw', + rel_pos_type='mlp', + rel_pos_dim=512, # was supposed to be 512, woops + ), + ), + coatnet_rmlp_2_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=(64, 128), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + rel_pos_type='mlp' + ), + ), + coatnet_rmlp_3_rw_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=(96, 192), + **_rw_coat_cfg( + stride_mode='dw', + conv_attn_act_layer='silu', + init_values=1e-6, + rel_pos_type='mlp' + ), + ), -class MaxxVit(nn.Module): - """ CoaTNet + MaxVit base model. + coatnet_nano_cc_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + block_type=('C', 'C', ('C', 'T'), ('C', 'T')), + **_rw_coat_cfg(), + ), + coatnext_nano_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(3, 4, 6, 3), + stem_width=(32, 64), + weight_init='normal', + **_next_cfg( + rel_pos_type='bias', + init_values=(1e-5, None) + ), + ), - Highly configurable for different block compositions, tensor layouts, pooling types. - """ + # Trying to be like the CoAtNet paper configs + coatnet_0_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 3, 5, 2), + stem_width=64, + ), + coatnet_1_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + stem_width=64, + ), + coatnet_2_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + stem_width=128, + ), + coatnet_3_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + stem_width=192, + ), + coatnet_4_224=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 12, 28, 2), + stem_width=192, + ), + coatnet_5_224=MaxxVitCfg( + embed_dim=(256, 512, 1280, 2048), + depths=(2, 12, 28, 2), + stem_width=192, + ), - def __init__( - self, - cfg: MaxxVitCfg, - img_size: Union[int, Tuple[int, int]] = 224, - in_chans: int = 3, - num_classes: int = 1000, - global_pool: str = 'avg', - drop_rate: float = 0., - drop_path_rate: float = 0. - ): - super().__init__() - img_size = to_2tuple(img_size) - transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) - self.num_classes = num_classes - self.global_pool = global_pool - self.num_features = cfg.embed_dim[-1] - self.embed_dim = cfg.embed_dim - self.drop_rate = drop_rate - self.grad_checkpointing = False + # Experimental MaxVit configs + maxvit_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(), + ), + maxvit_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw_224=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + maxvit_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), - self.stem = Stem( - in_chs=in_chans, - out_chs=cfg.stem_width, - act_layer=cfg.conv_cfg.act_layer, - norm_layer=cfg.conv_cfg.norm_layer, - norm_eps=cfg.conv_cfg.norm_eps, - ) + maxvit_rmlp_pico_rw_256=MaxxVitCfg( + embed_dim=(32, 64, 128, 256), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(24, 32), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(rel_pos_type='mlp'), + ), + maxvit_rmlp_small_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg( + rel_pos_type='mlp', + init_values=1e-6, + ), + ), + maxvit_rmlp_small_rw_256=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_rw_max_cfg( + rel_pos_type='mlp', + init_values=1e-6, + ), + ), - stride = self.stem.stride - feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) + maxvit_tiny_pm_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('PM',) * 4, + stem_width=(32, 64), + **_rw_max_cfg(), + ), + + maxxvit_rmlp_nano_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(1, 2, 3, 1), + block_type=('M',) * 4, + stem_width=(32, 64), + weight_init='normal', + **_next_cfg(), + ), + maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(32, 64), + **_next_cfg(), + ), + maxxvit_rmlp_small_rw_256=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=(48, 96), + **_next_cfg(), + ), + maxxvit_rmlp_base_rw_224=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=(48, 96), + **_next_cfg(), + ), + maxxvit_rmlp_large_rw_224=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 12, 2), + block_type=('M',) * 4, + stem_width=(64, 128), + **_next_cfg(), + ), - num_stages = len(cfg.embed_dim) - assert len(cfg.depths) == num_stages - dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] - in_chs = self.stem.out_chs - stages = [] - for i in range(num_stages): - stage_stride = 2 - out_chs = cfg.embed_dim[i] - feat_size = tuple([(r - 1) // stage_stride + 1 for r in feat_size]) - stages += [MaxxVitStage( - in_chs, - out_chs, - depth=cfg.depths[i], - block_types=cfg.block_type[i], - conv_cfg=cfg.conv_cfg, - transformer_cfg=transformer_cfg, - feat_size=feat_size, - drop_path=dpr[i], - )] - stride *= stage_stride - in_chs = out_chs - self.stages = nn.Sequential(*stages) + # Trying to be like the MaxViT paper configs + maxvit_tiny_tf=MaxxVitCfg( + embed_dim=(64, 128, 256, 512), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=512, + **_tf_cfg(), + ), + maxvit_small_tf=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 2, 5, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), + ), + maxvit_base_tf=MaxxVitCfg( + embed_dim=(96, 192, 384, 768), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=64, + stem_bias=True, + head_hidden_size=768, + **_tf_cfg(), + ), + maxvit_large_tf=MaxxVitCfg( + embed_dim=(128, 256, 512, 1024), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=128, + stem_bias=True, + head_hidden_size=1024, + **_tf_cfg(), + ), + maxvit_xlarge_tf=MaxxVitCfg( + embed_dim=(192, 384, 768, 1536), + depths=(2, 6, 14, 2), + block_type=('M',) * 4, + stem_width=192, + stem_bias=True, + head_hidden_size=1536, + **_tf_cfg(), + ), +) - final_norm_layer = get_norm_layer(cfg.transformer_cfg.norm_layer) - self.norm = final_norm_layer(self.num_features, eps=cfg.transformer_cfg.norm_eps) - # Classifier head - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) +def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): + return build_model_with_cfg( + MaxxVit, variant, pretrained, + model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + feature_cfg=dict(flatten_sequential=True), + **kwargs) - # Weight init (default PyTorch init works well for AdamW if scheme not set) - assert cfg.weight_init in ('', 'normal', 'trunc_normal', 'xavier_normal', 'vit_eff') - if cfg.weight_init: - named_apply(partial(self._init_weights, scheme=cfg.weight_init), self) - def _init_weights(self, module, name, scheme=''): - if hasattr(module, 'init_weights'): - try: - module.init_weights(scheme=scheme) - except TypeError: - module.init_weights() +def _cfg(url='', **kwargs): + return { + 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), + 'first_conv': 'stem.conv1', 'classifier': 'head.fc', + 'fixed_input_size': True, + **kwargs + } - @torch.jit.ignore - def no_weight_decay(self): - return { - k for k, _ in self.named_parameters() - if any(n in k for n in ["relative_position_bias_table", "rel_pos.mlp"])} - @torch.jit.ignore - def group_matcher(self, coarse=False): - matcher = dict( - stem=r'^stem', # stem and embed - blocks=[(r'^stages\.(\d+)', None), (r'^norm', (99999,))] - ) - return matcher +default_cfgs = generate_defaults({ + # Fiddling with configs / defaults / still pretraining + 'coatnet_pico_rw_224': _cfg(url=''), + 'coatnet_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', + crop_pct=0.9), + 'coatnet_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), + 'coatnet_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' + ), + 'coatnet_2_rw_224': _cfg(url=''), + 'coatnet_3_rw_224': _cfg(url=''), - @torch.jit.ignore - def set_grad_checkpointing(self, enable=True): - for s in self.stages: - s.grad_checkpointing = enable + # Highly experimental configs + 'coatnet_bn_0_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, + crop_pct=0.95), + 'coatnet_rmlp_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', + crop_pct=0.9), + 'coatnet_rmlp_0_rw_224': _cfg(url=''), + 'coatnet_rmlp_1_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), + 'coatnet_rmlp_1_rw2_224': _cfg(url=''), + 'coatnet_rmlp_2_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), + 'coatnet_rmlp_3_rw_224': _cfg(url=''), + 'coatnet_nano_cc_224': _cfg(url=''), + 'coatnext_nano_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', + crop_pct=0.9), - @torch.jit.ignore - def get_classifier(self): - return self.head.fc + # Trying to be like the CoAtNet paper configs + 'coatnet_0_224': _cfg(url=''), + 'coatnet_1_224': _cfg(url=''), + 'coatnet_2_224': _cfg(url=''), + 'coatnet_3_224': _cfg(url=''), + 'coatnet_4_224': _cfg(url=''), + 'coatnet_5_224': _cfg(url=''), - def reset_classifier(self, num_classes, global_pool=None): - self.num_classes = num_classes - if global_pool is None: - global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + # Experimental configs + 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_tiny_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), + 'maxvit_tiny_rw_256': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_pico_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_tiny_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_rmlp_small_rw_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', + crop_pct=0.9, + ), + 'maxvit_rmlp_small_rw_256': _cfg( + url='', + input_size=(3, 256, 256), pool_size=(8, 8)), - def forward_features(self, x): - x = self.stem(x) - x = self.stages(x) - x = self.norm(x) - return x + 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - def forward_head(self, x, pre_logits: bool = False): - return self.head(x, pre_logits=pre_logits) + 'maxxvit_rmlp_nano_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_base_rw_224': _cfg(url=''), + 'maxxvit_rmlp_large_rw_224': _cfg(url=''), - def forward(self, x): - x = self.forward_features(x) - x = self.forward_head(x) - return x + # Trying to be like the MaxViT paper configs + 'maxvit_tiny_tf_224.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_tiny_tf_384.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_tiny_tf_512.in1k': _cfg( + url='', + #file='maxvit_tiny_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_small_tf_224.in1k': _cfg( + url='', + #file='maxvit_small_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_small_tf_384.in1k': _cfg( + url='', + #file='maxvit_small_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_small_tf_512.in1k': _cfg( + url='', + #file='maxvit_small_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_224.in1k': _cfg( + url='', + #file='maxvit_base_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_base_tf_384.in1k': _cfg( + url='', + #file='maxvit_base_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_512.in1k': _cfg( + url='', + #file='maxvit_base_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_224.in1k': _cfg( + url='', + #file='maxvit_large_tf_224_in1k.pth', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_large_tf_384.in1k': _cfg( + url='', + #file='maxvit_large_tf_384_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_512.in1k': _cfg( + url='', + #file='maxvit_large_tf_512_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), -def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): - return build_model_with_cfg( - MaxxVit, variant, pretrained, - model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], - feature_cfg=dict(flatten_sequential=True), - **kwargs) + 'maxvit_base_tf_224.in21k': _cfg( + url='', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), + 'maxvit_base_tf_384.in21k_ft1k': _cfg( + url='', + #file='maxvit_base_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_base_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_base_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_224.in21k': _cfg( + url=''), + 'maxvit_large_tf_384.in21k_ft1k': _cfg( + url='', + #file='maxvit_large_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_large_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_large_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + 'maxvit_xlarge_tf_224.in21k': _cfg( + url=''), + 'maxvit_xlarge_tf_384.in21k_ft1k': _cfg( + url='', + #file='maxvit_xlarge_tf_384_in21k_ft_in1k.pth', + input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + 'maxvit_xlarge_tf_512.in21k_ft1k': _cfg( + url='', + #file='maxvit_xlarge_tf_512_in21k_ft_in1k.pth', + input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), +}) @register_model @@ -1773,6 +2075,11 @@ def coatnet_rmlp_1_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_1_rw_224', pretrained=pretrained, **kwargs) +@register_model +def coatnet_rmlp_1_rw2_224(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_1_rw2_224', pretrained=pretrained, **kwargs) + + @register_model def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) @@ -1889,25 +2196,85 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): @register_model -def maxvit_tiny_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_tiny_224', pretrained=pretrained, **kwargs) +def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_224', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_384', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_tiny_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_tiny_tf_512', 'maxvit_tiny_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_224', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_384', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_small_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_small_tf_512', 'maxvit_small_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_224', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_384', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_base_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_base_tf_512', 'maxvit_base_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_224', 'maxvit_large_tf', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_large_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_384', 'maxvit_large_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_small_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_small_224', pretrained=pretrained, **kwargs) +def maxvit_large_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_large_tf_512', 'maxvit_large_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_base_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_base_224', pretrained=pretrained, **kwargs) +def maxvit_xlarge_tf_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_224', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_large_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_large_224', pretrained=pretrained, **kwargs) +def maxvit_xlarge_tf_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_384', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) @register_model -def maxvit_xlarge_224(pretrained=False, **kwargs): - return _create_maxxvit('maxvit_xlarge_224', pretrained=pretrained, **kwargs) \ No newline at end of file +def maxvit_xlarge_tf_512(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs) \ No newline at end of file diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index f29216c9..cde0018b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -32,7 +32,7 @@ import torch.utils.checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD,\ OPENAI_CLIP_MEAN, OPENAI_CLIP_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply, adapt_input_conv, checkpoint_seq +from .helpers import build_model_with_cfg, named_apply, adapt_input_conv, checkpoint_seq from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_ from ._pretrained import generate_defaults from .registry import register_model @@ -795,13 +795,15 @@ default_cfgs = generate_defaults({ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in1k', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( hf_hub_id='', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k', @@ -823,13 +825,15 @@ default_cfgs = generate_defaults({ mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), 'vit_large_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), 'vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336)), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', @@ -879,12 +883,16 @@ default_cfgs = generate_defaults({ 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', + hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', + hf_hub_id='timm/vit_base_patch16_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_large_patch14_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/vit_large_patch14_clip_224.openai_ft_in12k',