Merge remote-tracking branch 'upstream/main'

pull/1583/head
Fredo Guan 2 years ago
commit fb717056da

@@ -24,6 +24,59 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
 * ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
 * Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
+### Jan 20, 2023
+* Add two convnext 12k -> 1k fine-tunes at 384x384
+  * `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
+  * `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384
+* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models
+|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)|
+|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:|
+|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22|
+|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76|
+|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99|
+|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15|
+|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84|
+|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90|
+|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95|
+|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74|
+|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43|
+|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64|
+|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77|
+|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99|
+|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22|
+|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15|
+|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78|
+|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90|
+|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84|
+|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77|
+|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59|
+|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65|
+|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42|
+|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35|
+|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13|
+|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01|
+|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38|
+|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78|
+|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30|
+|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17|
+|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92|
+|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60|
+|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11|
+|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78|
+|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47|
+|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05|
+|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05|
+|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92|
+|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28|
+|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04|
+|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73|
+|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34|
+|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80|
+|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41|
+|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86|
 ### Jan 11, 2023
 * Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
   * `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)

@@ -27,8 +27,9 @@ NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
     'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
+    'eva_*', 'flexivit*'
 ]
+#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', '
 NUM_NON_STD = len(NON_STD_FILTERS)
 # exclude models that cause specific test failures
@@ -38,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ:
         '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
         '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
         '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
-        'swin*giant*', 'convnextv2_huge*', 'davit_giant', 'davit_huge']
+        'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
     NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
 else:
     EXCLUDE_FILTERS = []
@@ -53,7 +54,7 @@ MAX_JIT_SIZE = 320
 TARGET_FFEAT_SIZE = 96
 MAX_FFEAT_SIZE = 256
 TARGET_FWD_FX_SIZE = 128
-MAX_FWD_FX_SIZE = 224
+MAX_FWD_FX_SIZE = 256
 TARGET_BWD_FX_SIZE = 128
 MAX_BWD_FX_SIZE = 224
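
The filter lists in this file are glob patterns applied to model names. As a rough illustration of how such patterns select names, here is a minimal sketch using the standard-library `fnmatch` module. The helper `is_excluded` and the sample names are hypothetical; the test suite itself resolves these lists through timm's model-listing utilities, whose matching details may differ.

```python
import fnmatch

# A subset of the EXCLUDE_FILTERS patterns shown in the hunk above.
EXCLUDE_FILTERS = ['swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']


def is_excluded(model_name, patterns):
    """Return True if model_name matches any glob pattern (hypothetical helper)."""
    return any(fnmatch.fnmatchcase(model_name, p) for p in patterns)


# Sample names for illustration only.
names = ['maxvit_xlarge_tf_512', 'maxvit_base_tf_224', 'davit_giant', 'davit_giant_x']
kept = [n for n in names if not is_excluded(n, EXCLUDE_FILTERS)]
print(kept)
```

A pattern without a wildcard, such as `davit_giant`, only matches that exact name, which is why the made-up `davit_giant_x` survives the filter in this sketch.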

@@ -1,6 +1,6 @@
 from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
     rand_augment_transform, auto_augment_transform
-from .config import resolve_data_config
+from .config import resolve_data_config, resolve_model_data_config
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset

@@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
 def resolve_data_config(
-        args,
-        default_cfg=None,
+        args=None,
+        pretrained_cfg=None,
         model=None,
         use_test_size=False,
         verbose=False
 ):
-    new_config = {}
-    default_cfg = default_cfg or {}
-    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
-        default_cfg = model.default_cfg
+    assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
+    args = args or {}
+    pretrained_cfg = pretrained_cfg or {}
+    if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
+        pretrained_cfg = model.pretrained_cfg
+    data_config = {}
     # Resolve input/image size
     in_chans = 3
@@ -32,65 +34,94 @@ def resolve_data_config(
         assert isinstance(args['img_size'], int)
         input_size = (in_chans, args['img_size'], args['img_size'])
     else:
-        if use_test_size and default_cfg.get('test_input_size', None) is not None:
-            input_size = default_cfg['test_input_size']
-        elif default_cfg.get('input_size', None) is not None:
-            input_size = default_cfg['input_size']
-    new_config['input_size'] = input_size
+        if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
+            input_size = pretrained_cfg['test_input_size']
+        elif pretrained_cfg.get('input_size', None) is not None:
+            input_size = pretrained_cfg['input_size']
+    data_config['input_size'] = input_size
     # resolve interpolation method
-    new_config['interpolation'] = 'bicubic'
+    data_config['interpolation'] = 'bicubic'
     if args.get('interpolation', None):
-        new_config['interpolation'] = args['interpolation']
-    elif default_cfg.get('interpolation', None):
-        new_config['interpolation'] = default_cfg['interpolation']
+        data_config['interpolation'] = args['interpolation']
+    elif pretrained_cfg.get('interpolation', None):
+        data_config['interpolation'] = pretrained_cfg['interpolation']
     # resolve dataset + model mean for normalization
-    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    data_config['mean'] = IMAGENET_DEFAULT_MEAN
     if args.get('mean', None) is not None:
         mean = tuple(args['mean'])
         if len(mean) == 1:
             mean = tuple(list(mean) * in_chans)
         else:
             assert len(mean) == in_chans
-        new_config['mean'] = mean
-    elif default_cfg.get('mean', None):
-        new_config['mean'] = default_cfg['mean']
+        data_config['mean'] = mean
+    elif pretrained_cfg.get('mean', None):
+        data_config['mean'] = pretrained_cfg['mean']
     # resolve dataset + model std deviation for normalization
-    new_config['std'] = IMAGENET_DEFAULT_STD
+    data_config['std'] = IMAGENET_DEFAULT_STD
     if args.get('std', None) is not None:
         std = tuple(args['std'])
         if len(std) == 1:
             std = tuple(list(std) * in_chans)
         else:
             assert len(std) == in_chans
-        new_config['std'] = std
-    elif default_cfg.get('std', None):
-        new_config['std'] = default_cfg['std']
+        data_config['std'] = std
+    elif pretrained_cfg.get('std', None):
+        data_config['std'] = pretrained_cfg['std']
     # resolve default inference crop
     crop_pct = DEFAULT_CROP_PCT
     if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
     else:
-        if use_test_size and default_cfg.get('test_crop_pct', None):
-            crop_pct = default_cfg['test_crop_pct']
-        elif default_cfg.get('crop_pct', None):
-            crop_pct = default_cfg['crop_pct']
-    new_config['crop_pct'] = crop_pct
+        if use_test_size and pretrained_cfg.get('test_crop_pct', None):
+            crop_pct = pretrained_cfg['test_crop_pct']
+        elif pretrained_cfg.get('crop_pct', None):
+            crop_pct = pretrained_cfg['crop_pct']
+    data_config['crop_pct'] = crop_pct
     # resolve default crop percentage
     crop_mode = DEFAULT_CROP_MODE
     if args.get('crop_mode', None):
         crop_mode = args['crop_mode']
-    elif default_cfg.get('crop_mode', None):
-        crop_mode = default_cfg['crop_mode']
-    new_config['crop_mode'] = crop_mode
+    elif pretrained_cfg.get('crop_mode', None):
+        crop_mode = pretrained_cfg['crop_mode']
+    data_config['crop_mode'] = crop_mode
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
-        for n, v in new_config.items():
+        for n, v in data_config.items():
             _logger.info('\t%s: %s' % (n, str(v)))
-    return new_config
+    return data_config
+def resolve_model_data_config(
+        model,
+        args=None,
+        pretrained_cfg=None,
+        use_test_size=False,
+        verbose=False,
+):
+    """ Resolve Model Data Config
+    This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
+    Args:
+        model (nn.Module): the model instance
+        args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
+        pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
+        use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
+        verbose (bool): enable extra logging of resolved values
+    Returns:
+        dictionary of config
+    """
+    return resolve_data_config(
+        args=args,
+        pretrained_cfg=pretrained_cfg,
+        model=model,
+        use_test_size=use_test_size,
+        verbose=verbose,
+    )
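
The precedence that `resolve_data_config()` applies (explicit `args` first, then `pretrained_cfg`, then a built-in default) may be easier to see in isolation. The following is a trimmed-down, hypothetical stand-in that covers only `interpolation` and `crop_pct`; the function name `resolve_sketch`, the sample config values, and the `0.875` default are assumptions for illustration, not timm's code.

```python
DEFAULT_CROP_PCT = 0.875  # assumed default, for illustration only


def resolve_sketch(args=None, pretrained_cfg=None, use_test_size=False):
    """Hypothetical stand-in showing the override order used by the hunk above."""
    args = args or {}
    pretrained_cfg = pretrained_cfg or {}
    data_config = {}

    # interpolation: args override pretrained_cfg, which overrides 'bicubic'
    data_config['interpolation'] = (
        args.get('interpolation') or pretrained_cfg.get('interpolation') or 'bicubic')

    # crop_pct: args first, then the test-time value (if requested), then the train-time value
    crop_pct = DEFAULT_CROP_PCT
    if args.get('crop_pct'):
        crop_pct = args['crop_pct']
    elif use_test_size and pretrained_cfg.get('test_crop_pct'):
        crop_pct = pretrained_cfg['test_crop_pct']
    elif pretrained_cfg.get('crop_pct'):
        crop_pct = pretrained_cfg['crop_pct']
    data_config['crop_pct'] = crop_pct
    return data_config


cfg = {'interpolation': 'bilinear', 'crop_pct': 0.95, 'test_crop_pct': 1.0}
train_cfg = resolve_sketch(pretrained_cfg=cfg)
test_cfg = resolve_sketch(pretrained_cfg=cfg, use_test_size=True)
override_cfg = resolve_sketch(args={'crop_pct': 0.9}, pretrained_cfg=cfg, use_test_size=True)
print(train_cfg, test_cfg, override_cfg)
```

Note that an explicit `crop_pct` in `args` wins even when `use_test_size=True`, mirroring the `if` / `else` nesting in the real function.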

@@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
 class ClassifierHead(nn.Module):
     """Classifier head w/ configurable global pooling and dropout."""
-    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
         super(ClassifierHead, self).__init__()
         self.drop_rate = drop_rate
-        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+        self.in_features = in_features
+        self.use_conv = use_conv
+        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
         self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            if global_pool != self.global_pool.pool_type:
+                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
+            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
+        num_pooled_features = self.in_features * self.global_pool.feat_mult()
+        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
         if self.drop_rate:

@@ -179,11 +179,11 @@ def load_pretrained(
         return
     if filter_fn is not None:
-        # for backwards compat with filter fn that take one arg, try one first, the two
         try:
-            state_dict = filter_fn(state_dict)
-        except TypeError:
             state_dict = filter_fn(state_dict, model)
+        except TypeError as e:
+            # for backwards compat with filter fn that take one arg
+            state_dict = filter_fn(state_dict)
     input_convs = pretrained_cfg.get('first_conv', None)
     if input_convs is not None and in_chans != 3:
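
The reordered `try` / `except` in this hunk calls the two-argument form first and only falls back to the one-argument form on a `TypeError`. A self-contained sketch of that pattern follows; `apply_filter`, `old_style`, `new_style`, and the toy state dict are hypothetical names used only for illustration, and a plain `set` stands in for the model.

```python
def apply_filter(filter_fn, state_dict, model):
    """Call filter_fn(state_dict, model), falling back to filter_fn(state_dict)."""
    try:
        return filter_fn(state_dict, model)
    except TypeError:
        # Older single-argument filter functions end up here.
        return filter_fn(state_dict)


def old_style(state_dict):
    # One-argument filter: drop classifier weights.
    return {k: v for k, v in state_dict.items() if not k.startswith('head.')}


def new_style(state_dict, model):
    # Two-argument filter: keep only keys the (stand-in) model knows about.
    return {k: v for k, v in state_dict.items() if k in model}


weights = {'stem.weight': 1, 'head.weight': 2, 'extra': 3}
model_keys = {'stem.weight', 'head.weight'}  # stand-in for a real nn.Module

old_result = apply_filter(old_style, weights, model_keys)
new_result = apply_filter(new_style, weights, model_keys)
print(old_result, new_result)
```

One caveat of this approach: a `TypeError` raised inside the body of a two-argument filter also triggers the fallback, so the filter ends up being called a second time with a single argument.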

@@ -236,20 +236,7 @@ def push_to_hf_hub(
             model_card = model_card or {}
             model_name = repo_id.split('/')[-1]
             readme_path = Path(tmpdir) / "README.md"
-            readme_text = "---\n"
-            readme_text += "tags:\n- image-classification\n- timm\n"
-            readme_text += "library_tag: timm\n"
-            readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
-            readme_text += "---\n"
-            readme_text += f"# Model card for {model_name}\n"
-            if 'description' in model_card:
-                readme_text += f"\n{model_card['description']}\n"
-            if 'details' in model_card:
-                readme_text += f"\n## Model Details\n"
-                for k, v in model_card['details'].items():
-                    readme_text += f"- **{k}:** {v}\n"
-            if 'citation' in model_card:
-                readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n"
+            readme_text = generate_readme(model_card, model_name)
             readme_path.write_text(readme_text)
         # Upload model and return
@@ -260,3 +247,51 @@ def push_to_hf_hub(
             create_pr=create_pr,
             commit_message=commit_message,
         )
+def generate_readme(model_card, model_name):
+    readme_text = "---\n"
+    readme_text += "tags:\n- image-classification\n- timm\n"
+    readme_text += "library_tag: timm\n"
+    readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
+    if 'details' in model_card and 'Dataset' in model_card['details']:
+        readme_text += 'datasets:\n'
+        readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
+        if 'Pretrain Dataset' in model_card['details']:
+            readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
+    readme_text += "---\n"
+    readme_text += f"# Model card for {model_name}\n"
+    if 'description' in model_card:
+        readme_text += f"\n{model_card['description']}\n"
+    if 'details' in model_card:
+        readme_text += f"\n## Model Details\n"
+        for k, v in model_card['details'].items():
+            if isinstance(v, (list, tuple)):
+                readme_text += f"- **{k}:**\n"
+                for vi in v:
+                    readme_text += f" - {vi}\n"
+            elif isinstance(v, dict):
+                readme_text += f"- **{k}:**\n"
+                for ki, vi in v.items():
+                    readme_text += f" - {ki}: {vi}\n"
+            else:
+                readme_text += f"- **{k}:** {v}\n"
+    if 'usage' in model_card:
+        readme_text += f"\n## Model Usage\n"
+        readme_text += model_card['usage']
+        readme_text += '\n'
+    if 'comparison' in model_card:
+        readme_text += f"\n## Model Comparison\n"
+        readme_text += model_card['comparison']
+        readme_text += '\n'
+    if 'citation' in model_card:
+        readme_text += f"\n## Citation\n"
+        if not isinstance(model_card['citation'], (list, tuple)):
+            citations = [model_card['citation']]
+        else:
+            citations = model_card['citation']
+        for c in citations:
+            readme_text += f"```bibtex\n{c}\n```\n"
+    return readme_text
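
To see what shape of `model_card['details']` the new `generate_readme()` accepts, here is a small sketch that mirrors only its "Model Details" loop: scalar values render inline, while lists and dicts render as sub-items. The helper name `render_details` is hypothetical, and the sample dict is made up for illustration (its numbers are borrowed from the `maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k` row of the README table in this commit).

```python
def render_details(details):
    """Hypothetical helper mirroring the 'Model Details' loop in generate_readme()."""
    text = "\n## Model Details\n"
    for k, v in details.items():
        if isinstance(v, (list, tuple)):
            # Sequence values become one sub-item per element.
            text += f"- **{k}:**\n"
            for vi in v:
                text += f" - {vi}\n"
        elif isinstance(v, dict):
            # Mapping values become one 'key: value' sub-item per entry.
            text += f"- **{k}:**\n"
            for ki, vi in v.items():
                text += f" - {ki}: {vi}\n"
        else:
            text += f"- **{k}:** {v}\n"
    return text


# Illustrative input only; the key names are not mandated by the source.
details = {
    'Model Type': 'Image classification / feature backbone',
    'Model Stats': {'Params (M)': 116.14, 'GMACs': 23.15},
    'Papers': ['MaxViT: Multi-Axis Vision Transformer'],
    'Dataset': 'ImageNet-1k',
}
rendered = render_details(details)
print(rendered)
```

In the full function, the `Dataset` (and optional `Pretrain Dataset`) entries are additionally lower-cased and written into the YAML front matter under `datasets:`.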

@@ -500,6 +500,13 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
+    'convnext_tiny.in12k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'convnext_small.in12k_ft_in1k_384': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'convnext_nano.in12k': _cfg(
         hf_hub_id='timm/',
         crop_pct=0.95, num_classes=11821),
@@ -706,27 +713,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
     'convnext_base.clip_laion2b_augreg': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
     'convnext_base.clip_laiona': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
+        input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
     'convnext_base.clip_laiona_320': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
     'convnext_base.clip_laiona_augreg_320': _cfg(
         hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
-        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
+        input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
 })
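
The `pool_size` values added in these configs are consistent with the input resolution divided by a total downsampling factor of 32. A quick arithmetic check, assuming that stride (it is not stated in the diff) and using a hypothetical helper name:

```python
def pool_size_for(input_size, total_stride=32):
    """Spatial size of the final feature map, assuming a total stride of 32."""
    _, height, width = input_size
    return (height // total_stride, width // total_stride)


# Resolutions that appear in the configs above.
sizes = {s: pool_size_for((3, s, s)) for s in (256, 320, 384)}
print(sizes)
```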

@@ -913,7 +913,7 @@ class CspNet(nn.Module):
         # Construct the head
         self.num_features = prev_chs
         self.head = ClassifierHead(
-            in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
# FIXME / WARNING
This impl remains a WIP, some configs and models may vanish or change...
Papers: Papers:
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
partition_ratio: int = 32 partition_ratio: int = 32
window_size: Optional[Tuple[int, int]] = None window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None grid_size: Optional[Tuple[int, int]] = None
no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
init_values: Optional[float] = None init_values: Optional[float] = None
act_layer: str = 'gelu' act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d' norm_layer: str = 'layernorm2d'
@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
stride: int = 1, stride: int = 1,
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention
drop_path: float = 0., drop_path: float = 0.,
): ):
super().__init__() super().__init__()
self.nchw_attn = transformer_cfg.use_nchw_attn
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
def init_weights(self, scheme=''): def init_weights(self, scheme=''):
@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
hidden_size=None, hidden_size=None,
pool_type='avg', pool_type='avg',
drop_rate=0., drop_rate=0.,
norm_layer=nn.LayerNorm, norm_layer='layernorm2d',
act_layer=nn.Tanh, act_layer='tanh',
): ):
super().__init__() super().__init__()
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features) self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity() self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size: if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([ self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)), ('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()), ('act', act_layer()),
])) ]))
self.num_features = hidden_size self.num_features = hidden_size
else: else:
self.pre_logits = nn.Identity() self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate) self.drop = nn.Dropout(self.drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False): def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x) x = self.global_pool(x)
@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module):
self.num_features = self.embed_dim = cfg.embed_dim[-1] self.num_features = self.embed_dim = cfg.embed_dim[-1]
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.grad_checkpointing = False self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem( self.stem = Stem(
in_chs=in_chans, in_chs=in_chans,
@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module):
norm_layer=cfg.conv_cfg.norm_layer, norm_layer=cfg.conv_cfg.norm_layer,
norm_eps=cfg.conv_cfg.norm_eps, norm_eps=cfg.conv_cfg.norm_eps,
) )
stride = self.stem.stride stride = self.stem.stride
self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
num_stages = len(cfg.embed_dim) num_stages = len(cfg.embed_dim)
@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module):
)] )]
stride *= stage_stride stride *= stage_stride
in_chs = out_chs in_chs = out_chs
self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
if cfg.head_hidden_size: self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size:
self.norm = nn.Identity() self.norm = nn.Identity()
self.head = NormMlpHead( self.head = NormMlpHead(
self.num_features, self.num_features,
num_classes, num_classes,
hidden_size=cfg.head_hidden_size, hidden_size=self.head_hidden_size,
pool_type=global_pool, pool_type=global_pool,
drop_rate=drop_rate, drop_rate=drop_rate,
norm_layer=final_norm_layer, norm_layer=final_norm_layer,
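Review note: the `feature_info` entries added in the hunks above record one dict for the stem and one per stage, with `reduction` tracking the cumulative stride. A minimal pure-Python sketch of that bookkeeping follows. The helper name `build_feature_info` and the assumption that the stem and every stage downsample by 2 are illustrative and not confirmed by the lines shown here (the diff hardcodes `reduction=2` only for the stem).

```python
def build_feature_info(stem_out_chs, embed_dim, stem_stride=2, stage_stride=2):
    """Hypothetical helper mirroring the feature_info bookkeeping in the diff."""
    feature_info = [dict(num_chs=stem_out_chs, reduction=stem_stride, module='stem')]
    stride = stem_stride
    for i, out_chs in enumerate(embed_dim):
        stride *= stage_stride  # assumed: each stage halves the resolution
        feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
    return feature_info


info = build_feature_info(64, (64, 128, 256, 512))
reductions = [f['reduction'] for f in info]
modules = [f['module'] for f in info]
```

Under those assumptions the reductions come out as 2, 4, 8, 16, 32, one per extractable feature level.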
@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module):
def reset_classifier(self, num_classes, global_pool=None): def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes self.num_classes = num_classes
if global_pool is None: self.head.reset(num_classes, global_pool)
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
@ -1376,6 +1396,7 @@ def _next_cfg(
transformer_norm_layer='layernorm2d', transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm', transformer_norm_layer_cl='layernorm',
window_size=None, window_size=None,
no_block_attn=False,
init_values=1e-6, init_values=1e-6,
rel_pos_type='mlp', # MLP by default for maxxvit rel_pos_type='mlp', # MLP by default for maxxvit
rel_pos_dim=512, rel_pos_dim=512,
@ -1396,6 +1417,7 @@ def _next_cfg(
expand_first=False, expand_first=False,
pool_type=pool_type, pool_type=pool_type,
window_size=window_size, window_size=window_size,
no_block_attn=no_block_attn, # enabled for MaxxViT-V2
init_values=init_values[1], init_values=init_values[1],
norm_layer=transformer_norm_layer, norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl, norm_layer_cl=transformer_norm_layer_cl,
@ -1422,8 +1444,8 @@ def _tf_cfg():
model_cfgs = dict( model_cfgs = dict(
# Fiddling with configs / defaults / still pretraining # timm specific CoAtNet configs
coatnet_pico_rw_224=MaxxVitCfg( coatnet_pico_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 3, 5, 2), depths=(2, 3, 5, 2),
stem_width=(32, 64), stem_width=(32, 64),
@ -1432,7 +1454,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25, conv_attn_ratio=0.25,
), ),
), ),
coatnet_nano_rw_224=MaxxVitCfg( coatnet_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3), depths=(3, 4, 6, 3),
stem_width=(32, 64), stem_width=(32, 64),
@ -1442,7 +1464,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25, conv_attn_ratio=0.25,
), ),
), ),
coatnet_0_rw_224=MaxxVitCfg( coatnet_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64), stem_width=(32, 64),
@ -1451,7 +1473,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False, transformer_shortcut_bias=False,
), ),
), ),
coatnet_1_rw_224=MaxxVitCfg( coatnet_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(32, 64), stem_width=(32, 64),
@ -1461,7 +1483,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False, transformer_shortcut_bias=False,
) )
), ),
coatnet_2_rw_224=MaxxVitCfg( coatnet_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024), embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(64, 128), stem_width=(64, 128),
@ -1471,7 +1493,7 @@ model_cfgs = dict(
#init_values=1e-6, #init_values=1e-6,
), ),
), ),
coatnet_3_rw_224=MaxxVitCfg( coatnet_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536), embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(96, 192), stem_width=(96, 192),
@ -1482,8 +1504,8 @@ model_cfgs = dict(
), ),
), ),
# Highly experimental configs # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
coatnet_bn_0_rw_224=MaxxVitCfg( coatnet_bn_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64), stem_width=(32, 64),
@ -1494,7 +1516,7 @@ model_cfgs = dict(
transformer_norm_layer='batchnorm2d', transformer_norm_layer='batchnorm2d',
) )
), ),
coatnet_rmlp_nano_rw_224=MaxxVitCfg( coatnet_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3), depths=(3, 4, 6, 3),
stem_width=(32, 64), stem_width=(32, 64),
@ -1505,7 +1527,7 @@ model_cfgs = dict(
rel_pos_dim=384, rel_pos_dim=384,
), ),
), ),
coatnet_rmlp_0_rw_224=MaxxVitCfg( coatnet_rmlp_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64), stem_width=(32, 64),
@ -1514,7 +1536,7 @@ model_cfgs = dict(
rel_pos_type='mlp', rel_pos_type='mlp',
), ),
), ),
coatnet_rmlp_1_rw_224=MaxxVitCfg( coatnet_rmlp_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(32, 64), stem_width=(32, 64),
@ -1526,7 +1548,7 @@ model_cfgs = dict(
rel_pos_dim=384, # was supposed to be 512, woops rel_pos_dim=384, # was supposed to be 512, woops
), ),
), ),
coatnet_rmlp_1_rw2_224=MaxxVitCfg( coatnet_rmlp_1_rw2=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(32, 64), stem_width=(32, 64),
@ -1536,7 +1558,7 @@ model_cfgs = dict(
rel_pos_dim=512, # was supposed to be 512, woops rel_pos_dim=512, # was supposed to be 512, woops
), ),
), ),
coatnet_rmlp_2_rw_224=MaxxVitCfg( coatnet_rmlp_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024), embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(64, 128), stem_width=(64, 128),
@ -1547,7 +1569,7 @@ model_cfgs = dict(
rel_pos_type='mlp' rel_pos_type='mlp'
), ),
), ),
coatnet_rmlp_3_rw_224=MaxxVitCfg( coatnet_rmlp_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536), embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=(96, 192), stem_width=(96, 192),
@ -1559,14 +1581,14 @@ model_cfgs = dict(
), ),
), ),
coatnet_nano_cc_224=MaxxVitCfg( coatnet_nano_cc=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3), depths=(3, 4, 6, 3),
stem_width=(32, 64), stem_width=(32, 64),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')), block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(), **_rw_coat_cfg(),
), ),
coatnext_nano_rw_224=MaxxVitCfg( coatnext_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3), depths=(3, 4, 6, 3),
stem_width=(32, 64), stem_width=(32, 64),
@ -1578,89 +1600,95 @@ model_cfgs = dict(
), ),
# Trying to be like the CoAtNet paper configs # Trying to be like the CoAtNet paper configs
coatnet_0_224=MaxxVitCfg( coatnet_0=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 3, 5, 2), depths=(2, 3, 5, 2),
stem_width=64, stem_width=64,
head_hidden_size=768,
), ),
coatnet_1_224=MaxxVitCfg( coatnet_1=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=64, stem_width=64,
head_hidden_size=768,
), ),
coatnet_2_224=MaxxVitCfg( coatnet_2=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024), embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=128, stem_width=128,
head_hidden_size=1024,
), ),
coatnet_3_224=MaxxVitCfg( coatnet_3=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536), embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
stem_width=192, stem_width=192,
head_hidden_size=1536,
), ),
coatnet_4_224=MaxxVitCfg( coatnet_4=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536), embed_dim=(192, 384, 768, 1536),
depths=(2, 12, 28, 2), depths=(2, 12, 28, 2),
stem_width=192, stem_width=192,
head_hidden_size=1536,
), ),
coatnet_5_224=MaxxVitCfg( coatnet_5=MaxxVitCfg(
embed_dim=(256, 512, 1280, 2048), embed_dim=(256, 512, 1280, 2048),
depths=(2, 12, 28, 2), depths=(2, 12, 28, 2),
stem_width=192, stem_width=192,
head_hidden_size=2048,
), ),
# Experimental MaxVit configs # Experimental MaxVit configs
maxvit_pico_rw_256=MaxxVitCfg( maxvit_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256), embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(24, 32), stem_width=(24, 32),
**_rw_max_cfg(), **_rw_max_cfg(),
), ),
maxvit_nano_rw_256=MaxxVitCfg( maxvit_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1), depths=(1, 2, 3, 1),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(), **_rw_max_cfg(),
), ),
maxvit_tiny_rw_224=MaxxVitCfg( maxvit_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(), **_rw_max_cfg(),
), ),
maxvit_tiny_rw_256=MaxxVitCfg( maxvit_tiny_pm=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('PM',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(), **_rw_max_cfg(),
), ),
maxvit_rmlp_pico_rw_256=MaxxVitCfg( maxvit_rmlp_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256), embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(24, 32), stem_width=(24, 32),
**_rw_max_cfg(rel_pos_type='mlp'), **_rw_max_cfg(rel_pos_type='mlp'),
), ),
maxvit_rmlp_nano_rw_256=MaxxVitCfg( maxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1), depths=(1, 2, 3, 1),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'), **_rw_max_cfg(rel_pos_type='mlp'),
), ),
maxvit_rmlp_tiny_rw_256=MaxxVitCfg( maxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'), **_rw_max_cfg(rel_pos_type='mlp'),
), ),
maxvit_rmlp_small_rw_224=MaxxVitCfg( maxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
@ -1670,27 +1698,7 @@ model_cfgs = dict(
init_values=1e-6, init_values=1e-6,
), ),
), ),
maxvit_rmlp_small_rw_256=MaxxVitCfg( maxvit_rmlp_base_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_rmlp_base_rw_224=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
),
),
maxvit_rmlp_base_rw_384=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(2, 6, 14, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
@ -1701,15 +1709,7 @@ model_cfgs = dict(
), ),
), ),
maxvit_tiny_pm_256=MaxxVitCfg( maxxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1), depths=(1, 2, 3, 1),
block_type=('M',) * 4, block_type=('M',) * 4,
@ -1717,33 +1717,50 @@ model_cfgs = dict(
weight_init='normal', weight_init='normal',
**_next_cfg(), **_next_cfg(),
), ),
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( maxxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512), embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(32, 64), stem_width=(32, 64),
**_next_cfg(), **_next_cfg(),
), ),
maxxvit_rmlp_small_rw_256=MaxxVitCfg( maxxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2), depths=(2, 2, 5, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(48, 96), stem_width=(48, 96),
**_next_cfg(), **_next_cfg(),
), ),
maxxvit_rmlp_base_rw_224=MaxxVitCfg(
maxxvitv2_nano_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768), embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2), depths=(1, 2, 3, 1),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(48, 96), stem_width=(48, 96),
**_next_cfg(), weight_init='normal',
**_next_cfg(
no_block_attn=True,
rel_pos_type='bias',
),
), ),
maxxvit_rmlp_large_rw_224=MaxxVitCfg( maxxvitv2_rmlp_base_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024), embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 12, 2), depths=(2, 6, 12, 2),
block_type=('M',) * 4, block_type=('M',) * 4,
stem_width=(64, 128), stem_width=(64, 128),
**_next_cfg(), **_next_cfg(
no_block_attn=True,
),
),
maxxvitv2_rmlp_large_rw=MaxxVitCfg(
embed_dim=(160, 320, 640, 1280),
depths=(2, 6, 16, 2),
block_type=('M',) * 4,
stem_width=(80, 160),
head_hidden_size=1280,
**_next_cfg(
no_block_attn=True,
),
), ),
# Trying to be like the MaxViT paper configs # Trying to be like the MaxViT paper configs
@ -1795,11 +1812,29 @@ model_cfgs = dict(
) )
def checkpoint_filter_fn(state_dict, model: nn.Module):
model_state_dict = model.state_dict()
out_dict = {}
for k, v in state_dict.items():
if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
# adapt between conv2d / linear layers
assert v.ndim in (2, 4)
v = v.reshape(model_state_dict[k].shape)
out_dict[k] = v
return out_dict
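Review note: `checkpoint_filter_fn` lets a checkpoint saved with a conv-style head load into a model built with a linear head, or the reverse. It reshapes any tensor whose rank differs from the model's but whose element count matches. The sketch below reproduces that rule on two toy layers. The function name `adapt_conv_linear` and the layer sizes are assumptions for illustration.

```python
import torch
import torch.nn as nn


def adapt_conv_linear(state_dict, model):
    """Reshape tensors that differ in rank but not in element count (sketch)."""
    target = model.state_dict()
    out = {}
    for k, v in state_dict.items():
        if k in target and v.ndim != target[k].ndim and v.numel() == target[k].numel():
            # e.g. Conv2d 1x1 weight (out, in, 1, 1) -> Linear weight (out, in)
            v = v.reshape(target[k].shape)
        out[k] = v
    return out


# Toy "checkpoint" from a 1x1 conv head, loaded into a linear head.
conv_head = nn.Conv2d(8, 4, kernel_size=1)
linear_head = nn.Linear(8, 4)

adapted = adapt_conv_linear(conv_head.state_dict(), linear_head)
linear_head.load_state_dict(adapted)  # strict load succeeds after the reshape

weight_shape = tuple(adapted['weight'].shape)
bias_shape = tuple(adapted['bias'].shape)
```

The bias passes through unchanged because its rank already matches. Only the weight is reshaped.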
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
if cfg_variant is None:
if variant in model_cfgs:
cfg_variant = variant
else:
cfg_variant = '_'.join(variant.split('_')[:-1])
return build_model_with_cfg( return build_model_with_cfg(
MaxxVit, variant, pretrained, MaxxVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], model_cfg=model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True), feature_cfg=dict(flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs) **kwargs)
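Review note: because the `model_cfgs` keys lose their resolution suffix in this commit (`coatnet_rmlp_2_rw_224` becomes `coatnet_rmlp_2_rw`), `_create_maxxvit` now derives the config key by dropping the last `_`-separated token when the full variant name is not itself a key. A standalone sketch of that lookup follows. The helper name `resolve_cfg_variant` and the key `example_fixed_224` are hypothetical.

```python
def resolve_cfg_variant(variant, model_cfgs, cfg_variant=None):
    """Mirror of the fallback in _create_maxxvit (hypothetical helper name)."""
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        else:
            # strip the trailing resolution token, e.g. '_224' or '_384'
            cfg_variant = '_'.join(variant.split('_')[:-1])
    return cfg_variant


# Stand-in config table; values are placeholders, not real MaxxVitCfg objects.
model_cfgs = {'coatnet_rmlp_2_rw': object(), 'example_fixed_224': object()}

a = resolve_cfg_variant('coatnet_rmlp_2_rw_224', model_cfgs)
b = resolve_cfg_variant('coatnet_rmlp_2_rw_384', model_cfgs)
c = resolve_cfg_variant('example_fixed_224', model_cfgs)
```

Both the 224 and 384 entrypoints resolve to the same architecture config, which is how the new `coatnet_rmlp_2_rw_384` entrypoint works without its own `model_cfgs` entry.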
@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({ default_cfgs = generate_default_cfgs({
# Fiddling with configs / defaults / still pretraining # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
'coatnet_pico_rw_224': _cfg(url=''), 'coatnet_pico_rw_224.untrained': _cfg(url=''),
'coatnet_nano_rw_224': _cfg( 'coatnet_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
crop_pct=0.9), crop_pct=0.9),
'coatnet_0_rw_224': _cfg( 'coatnet_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
'coatnet_1_rw_224': _cfg( 'coatnet_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
), ),
'coatnet_2_rw_224': _cfg(url=''),
'coatnet_3_rw_224': _cfg(url=''),
# Highly experimental configs # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
'coatnet_bn_0_rw_224': _cfg( 'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
#'coatnet_3_rw_224.untrained': _cfg(url=''),
# Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
'coatnet_bn_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95), crop_pct=0.95),
'coatnet_rmlp_nano_rw_224': _cfg( 'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
crop_pct=0.9), crop_pct=0.9),
'coatnet_rmlp_0_rw_224': _cfg(url=''), 'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg( 'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_rmlp_1_rw2_224': _cfg(url=''), 'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
'coatnet_rmlp_2_rw_224': _cfg( hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
'coatnet_rmlp_3_rw_224': _cfg(url=''), 'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
'coatnet_nano_cc_224': _cfg(url=''), 'coatnet_nano_cc_224.untrained': _cfg(url=''),
'coatnext_nano_rw_224': _cfg( 'coatnext_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
crop_pct=0.9), crop_pct=0.9),
# Trying to be like the CoAtNet paper configs # ImageNet-12k pretrain CoAtNet
'coatnet_0_224': _cfg(url=''), 'coatnet_2_rw_224.sw_in12k': _cfg(
'coatnet_1_224': _cfg(url=''), hf_hub_id='timm/',
'coatnet_2_224': _cfg(url=''), num_classes=11821),
'coatnet_3_224': _cfg(url=''), 'coatnet_3_rw_224.sw_in12k': _cfg(
'coatnet_4_224': _cfg(url=''), hf_hub_id='timm/',
'coatnet_5_224': _cfg(url=''), num_classes=11821),
'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
# Experimental configs hf_hub_id='timm/',
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), num_classes=11821),
'maxvit_nano_rw_256': _cfg( 'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
'coatnet_0_224.untrained': _cfg(url=''),
'coatnet_1_224.untrained': _cfg(url=''),
'coatnet_2_224.untrained': _cfg(url=''),
'coatnet_3_224.untrained': _cfg(url=''),
'coatnet_4_224.untrained': _cfg(url=''),
'coatnet_5_224.untrained': _cfg(url=''),
# timm specific MaxVit configs, ImageNet-1k pretrain or untrained
'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg( 'maxvit_tiny_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
'maxvit_tiny_rw_256': _cfg( 'maxvit_tiny_rw_256.untrained': _cfg(
url='', url='',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_pico_rw_256': _cfg( 'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg( 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_tiny_rw_256': _cfg( 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_small_rw_224': _cfg( 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
crop_pct=0.9, crop_pct=0.9,
), ),
'maxvit_rmlp_small_rw_256': _cfg( 'maxvit_rmlp_small_rw_256.untrained': _cfg(
url='', url='',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_base_rw_224': _cfg(
url='',
),
'maxvit_rmlp_base_rw_384': _cfg(
url='',
input_size=(3, 384, 384), pool_size=(12, 12)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
),
'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# timm specific MaxVit w/ ImageNet-12k pretrain
'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
),
'maxxvit_rmlp_nano_rw_256': _cfg( # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256': _cfg( 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
input_size=(3, 256, 256), pool_size=(8, 8)), input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_base_rw_224': _cfg(url=''),
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
# timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# MaxViT models ported from official Tensorflow impl # MaxViT models ported from official Tensorflow impl
'maxvit_tiny_tf_224.in1k': _cfg( 'maxvit_tiny_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_224.in1k', hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_tiny_tf_384.in1k': _cfg( 'maxvit_tiny_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_384.in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_tf_512.in1k': _cfg( 'maxvit_tiny_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_512.in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_224.in1k': _cfg( 'maxvit_small_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_224.in1k', hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_small_tf_384.in1k': _cfg( 'maxvit_small_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_384.in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_512.in1k': _cfg( 'maxvit_small_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_512.in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in1k': _cfg( 'maxvit_base_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_224.in1k', hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in1k': _cfg( 'maxvit_base_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in1k': _cfg( 'maxvit_base_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in1k': _cfg( 'maxvit_large_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_224.in1k', hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_large_tf_384.in1k': _cfg( 'maxvit_large_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in1k': _cfg( 'maxvit_large_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in21k': _cfg( 'maxvit_base_tf_224.in21k': _cfg(
url=''), url=''),
'maxvit_base_tf_384.in21k_ft_in1k': _cfg( 'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in21k_ft_in1k': _cfg( 'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in21k': _cfg( 'maxvit_large_tf_224.in21k': _cfg(
url=''), url=''),
'maxvit_large_tf_384.in21k_ft_in1k': _cfg( 'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in21k_ft_in1k': _cfg( 'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_224.in21k': _cfg( 'maxvit_xlarge_tf_224.in21k': _cfg(
url=''), url=''),
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg( 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg( 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k', hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
}) })
@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
@register_model @register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs): def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
@register_model @register_model
def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model @register_model
def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
@register_model @register_model
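Review note: the `pool_size` values added to the MaxViT `tf` configs in this hunk, (12, 12) at 384 and (16, 16) at 512, are consistent with a final feature map at 1/32 of the input resolution. The sketch below only illustrates that arithmetic. The helper name `expected_pool_size` and the `total_stride=32` default are assumptions inferred from the numbers in this diff; they are not part of timm.

```python
def expected_pool_size(input_size, total_stride=32):
    """Return the (H, W) of the final feature map for a (C, H, W) input size.

    total_stride=32 is an assumption inferred from the configs in this diff
    (384 -> 12, 512 -> 16); it is not read from the model definition.
    """
    _, height, width = input_size
    if height % total_stride or width % total_stride:
        raise ValueError(
            f"input {height}x{width} is not divisible by stride {total_stride}")
    return (height // total_stride, width // total_stride)


if __name__ == "__main__":
    # Check the two resolutions that carry an explicit pool_size in this hunk.
    for size in [(3, 384, 384), (3, 512, 512)]:
        print(size, "->", expected_pool_size(size))
```

A check like this could be used to sanity-test a new config entry before adding it, assuming the same overall stride applies.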

@ -496,7 +496,7 @@ class RegNet(nn.Module):
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width self.num_features = prev_width
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -1029,6 +1029,10 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin', hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.openai': _cfg( 'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/', hf_hub_id='timm/',
@ -1498,6 +1502,17 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
# Experimental models below # Experimental models below
@register_model @register_model
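Review note: the fractional `mlp_ratio=64/13` in `vit_gigantic_patch14_clip_224` looks odd but yields round widths for `embed_dim=1664`. The sketch below works through that arithmetic using only the values from the hunk above. The helper `derived_widths` is hypothetical, and `round()` is a choice made here to sidestep floating-point error; it is not necessarily how the model code computes the hidden size.

```python
def derived_widths(embed_dim, num_heads, mlp_ratio):
    """Return (per-head dim, MLP hidden dim) for a transformer block."""
    if embed_dim % num_heads:
        raise ValueError("embed_dim must be divisible by num_heads")
    head_dim = embed_dim // num_heads
    # round() guards against float error in the fractional ratio
    # (an assumption of this sketch, not taken from the diff).
    mlp_hidden = round(embed_dim * mlp_ratio)
    return head_dim, mlp_hidden


if __name__ == "__main__":
    # Values copied from model_kwargs in the diff above.
    head_dim, mlp_hidden = derived_widths(
        embed_dim=1664, num_heads=16, mlp_ratio=64 / 13)
    print(f"head_dim={head_dim}, mlp_hidden={mlp_hidden}")
```

Since 1664 = 13 * 128, the ratio 64/13 gives an MLP hidden width of 128 * 64 = 8192, and 16 heads give a per-head width of 104.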

@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity() self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@torch.jit.ignore @torch.jit.ignore
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):

@ -1 +1 @@
__version__ = '0.8.6dev0' __version__ = '0.8.8dev0'
