diff --git a/README.md b/README.md index ee07c368..287b6f66 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,59 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before * ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗ * Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch. +### Jan 20, 2023 +* Add two convnext 12k -> 1k fine-tunes at 384x384 + * `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384 + * `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384 + +* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models + +|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)| +|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:| +|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22| +|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76| +|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99| +|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15| +|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84| +|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90| +|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95| +|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74| +|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43| +|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64| +|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77| +|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99| +|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22| +|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15| +|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78| +|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90| +|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84| +|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77| +|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59| +|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65| +|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42| +|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35| +|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13| +|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01| +|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38| +|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78| +|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30| +|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17| +|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92| +|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60| +|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11| +|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78| +|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47| +|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05| +|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05| +|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92| +|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28| +|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04| +|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73| +|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34| +|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80| +|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41| +|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86| + ### Jan 11, 2023 * Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags) * `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released) diff --git a/tests/test_models.py b/tests/test_models.py index 15e222b7..eb470d5f 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,8 +27,9 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*' + 'eva_*', 'flexivit*' ] +#'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', ' NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -38,7 +39,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', - 'swin*giant*', 'convnextv2_huge*', 'davit_giant', 'davit_huge'] + 'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] @@ -53,7 +54,7 @@ MAX_JIT_SIZE = 320 TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 TARGET_FWD_FX_SIZE = 128 -MAX_FWD_FX_SIZE = 224 +MAX_FWD_FX_SIZE = 256 TARGET_BWD_FX_SIZE = 128 MAX_BWD_FX_SIZE = 224 diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7cc7b0b0..9f62a7d5 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform -from .config import resolve_data_config +from .config import resolve_data_config, resolve_model_data_config from .constants import * from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset diff --git a/timm/data/config.py b/timm/data/config.py index a65695d0..a6c2298c 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__) def resolve_data_config( - args, - default_cfg=None, + args=None, + pretrained_cfg=None, model=None, use_test_size=False, verbose=False ): - new_config = {} - default_cfg = default_cfg or {} - if not default_cfg and model is not None and hasattr(model, 'default_cfg'): - default_cfg = model.default_cfg + assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config." + args = args or {} + pretrained_cfg = pretrained_cfg or {} + if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'): + pretrained_cfg = model.pretrained_cfg + data_config = {} # Resolve input/image size in_chans = 3 @@ -32,65 +34,94 @@ def resolve_data_config( assert isinstance(args['img_size'], int) input_size = (in_chans, args['img_size'], args['img_size']) else: - if use_test_size and default_cfg.get('test_input_size', None) is not None: - input_size = default_cfg['test_input_size'] - elif default_cfg.get('input_size', None) is not None: - input_size = default_cfg['input_size'] - new_config['input_size'] = input_size + if use_test_size and pretrained_cfg.get('test_input_size', None) is not None: + input_size = pretrained_cfg['test_input_size'] + elif pretrained_cfg.get('input_size', None) is not None: + input_size = pretrained_cfg['input_size'] + data_config['input_size'] = input_size # resolve interpolation method - new_config['interpolation'] = 'bicubic' + data_config['interpolation'] = 'bicubic' if args.get('interpolation', None): - new_config['interpolation'] = args['interpolation'] - elif default_cfg.get('interpolation', None): - new_config['interpolation'] = default_cfg['interpolation'] + data_config['interpolation'] = args['interpolation'] + elif pretrained_cfg.get('interpolation', None): + data_config['interpolation'] = pretrained_cfg['interpolation'] # resolve dataset + model mean for normalization - new_config['mean'] = IMAGENET_DEFAULT_MEAN + data_config['mean'] = IMAGENET_DEFAULT_MEAN if args.get('mean', None) is not None: mean = tuple(args['mean']) if len(mean) == 1: mean = tuple(list(mean) * in_chans) else: assert len(mean) == in_chans - new_config['mean'] = mean - elif default_cfg.get('mean', None): - new_config['mean'] = default_cfg['mean'] + data_config['mean'] = mean + elif pretrained_cfg.get('mean', None): + data_config['mean'] = pretrained_cfg['mean'] # resolve dataset + model std deviation for normalization - new_config['std'] = IMAGENET_DEFAULT_STD + data_config['std'] = IMAGENET_DEFAULT_STD if args.get('std', None) is not None: std = tuple(args['std']) if len(std) == 1: std = tuple(list(std) * in_chans) else: assert len(std) == in_chans - new_config['std'] = std - elif default_cfg.get('std', None): - new_config['std'] = default_cfg['std'] + data_config['std'] = std + elif pretrained_cfg.get('std', None): + data_config['std'] = pretrained_cfg['std'] # resolve default inference crop crop_pct = DEFAULT_CROP_PCT if args.get('crop_pct', None): crop_pct = args['crop_pct'] else: - if use_test_size and default_cfg.get('test_crop_pct', None): - crop_pct = default_cfg['test_crop_pct'] - elif default_cfg.get('crop_pct', None): - crop_pct = default_cfg['crop_pct'] - new_config['crop_pct'] = crop_pct + if use_test_size and pretrained_cfg.get('test_crop_pct', None): + crop_pct = pretrained_cfg['test_crop_pct'] + elif pretrained_cfg.get('crop_pct', None): + crop_pct = pretrained_cfg['crop_pct'] + data_config['crop_pct'] = crop_pct # resolve default crop percentage crop_mode = DEFAULT_CROP_MODE if args.get('crop_mode', None): crop_mode = args['crop_mode'] - elif default_cfg.get('crop_mode', None): - crop_mode = default_cfg['crop_mode'] - new_config['crop_mode'] = crop_mode + elif pretrained_cfg.get('crop_mode', None): + crop_mode = pretrained_cfg['crop_mode'] + data_config['crop_mode'] = crop_mode if verbose: _logger.info('Data processing configuration for current model + dataset:') - for n, v in new_config.items(): + for n, v in data_config.items(): _logger.info('\t%s: %s' % (n, str(v))) - return new_config + return data_config + + +def resolve_model_data_config( + model, + args=None, + pretrained_cfg=None, + use_test_size=False, + verbose=False, +): + """ Resolve Model Data Config + This is equivalent to resolve_data_config() but with arguments re-ordered to put model first. + + Args: + model (nn.Module): the model instance + args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg) + pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model) + use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution + verbose (bool): enable extra logging of resolved values + + Returns: + dictionary of config + """ + return resolve_data_config( + args=args, + pretrained_cfg=pretrained_cfg, + model=model, + use_test_size=use_test_size, + verbose=verbose, + ) diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 3ac33387..e885084c 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False class ClassifierHead(nn.Module): """Classifier head w/ configurable global pooling and dropout.""" - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False): super(ClassifierHead, self).__init__() self.drop_rate = drop_rate - self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.in_features = in_features + self.use_conv = use_conv + + self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + if global_pool != self.global_pool.pool_type: + self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv) + self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity() + num_pooled_features = self.in_features * self.global_pool.feat_mult() + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv) + def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate: diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 901d7d44..32a35304 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -179,11 +179,11 @@ def load_pretrained( return if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two try: - state_dict = filter_fn(state_dict) - except TypeError: state_dict = filter_fn(state_dict, model) + except TypeError as e: + # for backwards compat with filter fn that take one arg + state_dict = filter_fn(state_dict) input_convs = pretrained_cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: diff --git a/timm/models/_hub.py b/timm/models/_hub.py index df1a1ef7..378d646c 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -236,20 +236,7 @@ def push_to_hf_hub( model_card = model_card or {} model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" - readme_text = "---\n" - readme_text += "tags:\n- image-classification\n- timm\n" - readme_text += "library_tag: timm\n" - readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" - readme_text += "---\n" - readme_text += f"# Model card for {model_name}\n" - if 'description' in model_card: - readme_text += f"\n{model_card['description']}\n" - if 'details' in model_card: - readme_text += f"\n## Model Details\n" - for k, v in model_card['details'].items(): - readme_text += f"- **{k}:** {v}\n" - if 'citation' in model_card: - readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n" + readme_text = generate_readme(model_card, model_name) readme_path.write_text(readme_text) # Upload model and return @@ -260,3 +247,51 @@ def push_to_hf_hub( create_pr=create_pr, commit_message=commit_message, ) + + +def generate_readme(model_card, model_name): + readme_text = "---\n" + readme_text += "tags:\n- image-classification\n- timm\n" + readme_text += "library_tag: timm\n" + readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + if 'Pretrain Dataset' in model_card['details']: + readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + return readme_text diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 05e29a73..2bbe0b11 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -500,6 +500,13 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_tiny.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_small.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_nano.in12k': _cfg( hf_hub_id='timm/', crop_pct=0.95, num_classes=11821), @@ -706,27 +713,27 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), 'convnext_base.clip_laion2b_augreg': _cfg( hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), 'convnext_base.clip_laiona': _cfg( hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 256, 256), crop_pct=1.0, num_classes=640), + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), 'convnext_base.clip_laiona_320': _cfg( hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 320, 320), crop_pct=1.0, num_classes=640), + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), 'convnext_base.clip_laiona_augreg_320': _cfg( hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - input_size=(3, 320, 320), crop_pct=1.0, num_classes=640), + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), }) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 26ec54d9..da9d1ae0 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -913,7 +913,7 @@ class CspNet(nn.Module): # Construct the head self.num_features = prev_chs self.head = ClassifierHead( - in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index dd424078..e730fa30 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. -# FIXME / WARNING -This impl remains a WIP, some configs and models may vanish or change... - Papers: MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 @@ -76,6 +73,8 @@ class MaxxVitTransformerCfg: partition_ratio: int = 32 window_size: Optional[Tuple[int, int]] = None grid_size: Optional[Tuple[int, int]] = None + no_block_attn: bool = False # disable window block attention for maxvit (ie only grid) + use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order init_values: Optional[float] = None act_layer: str = 'gelu' norm_layer: str = 'layernorm2d' @@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module): stride: int = 1, conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU - use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention drop_path: float = 0., ): super().__init__() + self.nchw_attn = transformer_cfg.use_nchw_attn conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl - self.nchw_attn = use_nchw_attn - self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None + partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl + self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) def init_weights(self, scheme=''): @@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module): hidden_size=None, pool_type='avg', drop_rate=0., - norm_layer=nn.LayerNorm, - act_layer=nn.Tanh, + norm_layer='layernorm2d', + act_layer='tanh', ): super().__init__() self.drop_rate = drop_rate + self.in_features = in_features + self.hidden_size = hidden_size self.num_features = in_features + self.use_conv = not pool_type + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.norm = norm_layer(in_features) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(in_features, hidden_size)), + ('fc', linear_layer(in_features, hidden_size)), ('act', act_layer()), ])) self.num_features = hidden_size else: self.pre_logits = nn.Identity() self.drop = nn.Dropout(self.drop_rate) - self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.use_conv = self.global_pool.is_identity() + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear + if self.hidden_size: + if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or + (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)): + with torch.no_grad(): + new_fc = linear_layer(self.in_features, self.hidden_size) + new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape)) + new_fc.bias.copy_(self.pre_logits.fc.bias) + self.pre_logits.fc = new_fc + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) @@ -1163,6 +1182,7 @@ class MaxxVit(nn.Module): self.num_features = self.embed_dim = cfg.embed_dim[-1] self.drop_rate = drop_rate self.grad_checkpointing = False + self.feature_info = [] self.stem = Stem( in_chs=in_chans, @@ -1173,8 +1193,8 @@ class MaxxVit(nn.Module): norm_layer=cfg.conv_cfg.norm_layer, norm_eps=cfg.conv_cfg.norm_eps, ) - stride = self.stem.stride + self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')] feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) num_stages = len(cfg.embed_dim) @@ -1198,15 +1218,17 @@ class MaxxVit(nn.Module): )] stride *= stage_stride in_chs = out_chs + self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) - if cfg.head_hidden_size: + self.head_hidden_size = cfg.head_hidden_size + if self.head_hidden_size: self.norm = nn.Identity() self.head = NormMlpHead( self.num_features, num_classes, - hidden_size=cfg.head_hidden_size, + hidden_size=self.head_hidden_size, pool_type=global_pool, drop_rate=drop_rate, norm_layer=final_norm_layer, @@ -1253,9 +1275,7 @@ class MaxxVit(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if global_pool is None: - global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) @@ -1376,6 +1396,7 @@ def _next_cfg( transformer_norm_layer='layernorm2d', transformer_norm_layer_cl='layernorm', window_size=None, + no_block_attn=False, init_values=1e-6, rel_pos_type='mlp', # MLP by default for maxxvit rel_pos_dim=512, @@ -1396,6 +1417,7 @@ def _next_cfg( expand_first=False, pool_type=pool_type, window_size=window_size, + no_block_attn=no_block_attn, # enabled for MaxxViT-V2 init_values=init_values[1], norm_layer=transformer_norm_layer, norm_layer_cl=transformer_norm_layer_cl, @@ -1422,8 +1444,8 @@ def _tf_cfg(): model_cfgs = dict( - # Fiddling with configs / defaults / still pretraining - coatnet_pico_rw_224=MaxxVitCfg( + # timm specific CoAtNet configs + coatnet_pico_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 3, 5, 2), stem_width=(32, 64), @@ -1432,7 +1454,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_nano_rw_224=MaxxVitCfg( + coatnet_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1442,7 +1464,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_0_rw_224=MaxxVitCfg( + coatnet_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1451,7 +1473,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ), ), - coatnet_1_rw_224=MaxxVitCfg( + coatnet_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1461,7 +1483,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ) ), - coatnet_2_rw_224=MaxxVitCfg( + coatnet_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1471,7 +1493,7 @@ model_cfgs = dict( #init_values=1e-6, ), ), - coatnet_3_rw_224=MaxxVitCfg( + coatnet_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1482,8 +1504,8 @@ model_cfgs = dict( ), ), - # Highly experimental configs - coatnet_bn_0_rw_224=MaxxVitCfg( + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + coatnet_bn_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1494,7 +1516,7 @@ model_cfgs = dict( transformer_norm_layer='batchnorm2d', ) ), - coatnet_rmlp_nano_rw_224=MaxxVitCfg( + coatnet_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1505,7 +1527,7 @@ model_cfgs = dict( rel_pos_dim=384, ), ), - coatnet_rmlp_0_rw_224=MaxxVitCfg( + coatnet_rmlp_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1514,7 +1536,7 @@ model_cfgs = dict( rel_pos_type='mlp', ), ), - coatnet_rmlp_1_rw_224=MaxxVitCfg( + coatnet_rmlp_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1526,7 +1548,7 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), - coatnet_rmlp_1_rw2_224=MaxxVitCfg( + coatnet_rmlp_1_rw2=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1536,7 +1558,7 @@ model_cfgs = dict( rel_pos_dim=512, # was supposed to be 512, woops ), ), - coatnet_rmlp_2_rw_224=MaxxVitCfg( + coatnet_rmlp_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1547,7 +1569,7 @@ model_cfgs = dict( rel_pos_type='mlp' ), ), - coatnet_rmlp_3_rw_224=MaxxVitCfg( + coatnet_rmlp_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1559,14 +1581,14 @@ model_cfgs = dict( ), ), - coatnet_nano_cc_224=MaxxVitCfg( + coatnet_nano_cc=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), block_type=('C', 'C', ('C', 'T'), ('C', 'T')), **_rw_coat_cfg(), ), - coatnext_nano_rw_224=MaxxVitCfg( + coatnext_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1578,89 +1600,95 @@ model_cfgs = dict( ), # Trying to be like the CoAtNet paper configs - coatnet_0_224=MaxxVitCfg( + coatnet_0=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 5, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_1_224=MaxxVitCfg( + coatnet_1=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_2_224=MaxxVitCfg( + coatnet_2=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=128, + head_hidden_size=1024, ), - coatnet_3_224=MaxxVitCfg( + coatnet_3=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_4_224=MaxxVitCfg( + coatnet_4=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_5_224=MaxxVitCfg( + coatnet_5=MaxxVitCfg( embed_dim=(256, 512, 1280, 2048), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=2048, ), # Experimental MaxVit configs - maxvit_pico_rw_256=MaxxVitCfg( + maxvit_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(), ), - maxvit_nano_rw_256=MaxxVitCfg( + maxvit_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_224=MaxxVitCfg( + maxvit_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_256=MaxxVitCfg( + maxvit_tiny_pm=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), - block_type=('M',) * 4, + block_type=('PM',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_rmlp_pico_rw_256=MaxxVitCfg( + maxvit_rmlp_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_small_rw_224=MaxxVitCfg( + maxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, @@ -1670,27 +1698,7 @@ model_cfgs = dict( init_values=1e-6, ), ), - maxvit_rmlp_small_rw_256=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - **_rw_max_cfg( - rel_pos_type='mlp', - init_values=1e-6, - ), - ), - maxvit_rmlp_base_rw_224=MaxxVitCfg( - embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), - block_type=('M',) * 4, - stem_width=(32, 64), - head_hidden_size=768, - **_rw_max_cfg( - rel_pos_type='mlp', - ), - ), - maxvit_rmlp_base_rw_384=MaxxVitCfg( + maxvit_rmlp_base_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), block_type=('M',) * 4, @@ -1701,15 +1709,7 @@ model_cfgs = dict( ), ), - maxvit_tiny_pm_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('PM',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - - maxxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, @@ -1717,33 +1717,50 @@ model_cfgs = dict( weight_init='normal', **_next_cfg(), ), - maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_next_cfg(), ), - maxxvit_rmlp_small_rw_256=MaxxVitCfg( + maxxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(48, 96), **_next_cfg(), ), - maxxvit_rmlp_base_rw_224=MaxxVitCfg( + + maxxvitv2_nano_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), + depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(48, 96), - **_next_cfg(), + weight_init='normal', + **_next_cfg( + no_block_attn=True, + rel_pos_type='bias', + ), ), - maxxvit_rmlp_large_rw_224=MaxxVitCfg( + maxxvitv2_rmlp_base_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 12, 2), block_type=('M',) * 4, stem_width=(64, 128), - **_next_cfg(), + **_next_cfg( + no_block_attn=True, + ), + ), + maxxvitv2_rmlp_large_rw=MaxxVitCfg( + embed_dim=(160, 320, 640, 1280), + depths=(2, 6, 16, 2), + block_type=('M',) * 4, + stem_width=(80, 160), + head_hidden_size=1280, + **_next_cfg( + no_block_attn=True, + ), ), # Trying to be like the MaxViT paper configs @@ -1795,11 +1812,29 @@ model_cfgs = dict( ) +def checkpoint_filter_fn(state_dict, model: nn.Module): + model_state_dict = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): + # adapt between conv2d / linear layers + assert v.ndim in (2, 4) + v = v.reshape(model_state_dict[k].shape) + out_dict[k] = v + return out_dict + + def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): + if cfg_variant is None: + if variant in model_cfgs: + cfg_variant = variant + else: + cfg_variant = '_'.join(variant.split('_')[:-1]) return build_model_with_cfg( MaxxVit, variant, pretrained, - model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + model_cfg=model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) @@ -1815,155 +1850,218 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ - # Fiddling with configs / defaults / still pretraining - 'coatnet_pico_rw_224': _cfg(url=''), - 'coatnet_nano_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos + 'coatnet_pico_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', crop_pct=0.9), - 'coatnet_0_rw_224': _cfg( + 'coatnet_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), - 'coatnet_1_rw_224': _cfg( + 'coatnet_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' ), - 'coatnet_2_rw_224': _cfg(url=''), - 'coatnet_3_rw_224': _cfg(url=''), - # Highly experimental configs - 'coatnet_bn_0_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos + 'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + #'coatnet_3_rw_224.untrained': _cfg(url=''), + + # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos) + 'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + 'coatnet_bn_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.95), - 'coatnet_rmlp_nano_rw_224': _cfg( + 'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', crop_pct=0.9), - 'coatnet_rmlp_0_rw_224': _cfg(url=''), - 'coatnet_rmlp_1_rw_224': _cfg( + 'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''), + 'coatnet_rmlp_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), - 'coatnet_rmlp_1_rw2_224': _cfg(url=''), - 'coatnet_rmlp_2_rw_224': _cfg( + 'coatnet_rmlp_2_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), - 'coatnet_rmlp_3_rw_224': _cfg(url=''), - 'coatnet_nano_cc_224': _cfg(url=''), - 'coatnext_nano_rw_224': _cfg( + 'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_cc_224.untrained': _cfg(url=''), + 'coatnext_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', crop_pct=0.9), - # Trying to be like the CoAtNet paper configs - 'coatnet_0_224': _cfg(url=''), - 'coatnet_1_224': _cfg(url=''), - 'coatnet_2_224': _cfg(url=''), - 'coatnet_3_224': _cfg(url=''), - 'coatnet_4_224': _cfg(url=''), - 'coatnet_5_224': _cfg(url=''), - - # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg( + # ImagenNet-12k pretrain CoAtNet + 'coatnet_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_3_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released) + 'coatnet_0_224.untrained': _cfg(url=''), + 'coatnet_1_224.untrained': _cfg(url=''), + 'coatnet_2_224.untrained': _cfg(url=''), + 'coatnet_3_224.untrained': _cfg(url=''), + 'coatnet_4_224.untrained': _cfg(url=''), + 'coatnet_5_224.untrained': _cfg(url=''), + + # timm specific MaxVit configs, ImageNet-1k pretrain or untrained + 'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg( + 'maxvit_tiny_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), - 'maxvit_tiny_rw_256': _cfg( + 'maxvit_tiny_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_pico_rw_256': _cfg( + 'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + + # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain + 'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_nano_rw_256': _cfg( + 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_tiny_rw_256': _cfg( + 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_small_rw_224': _cfg( + 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', crop_pct=0.9, ), - 'maxvit_rmlp_small_rw_256': _cfg( + 'maxvit_rmlp_small_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_base_rw_224': _cfg( - url='', - ), - 'maxvit_rmlp_base_rw_384': _cfg( - url='', - input_size=(3, 384, 384), pool_size=(12, 12)), - 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune + 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + ), + 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # timm specific MaxVit w/ ImageNet-12k pretrain + 'maxvit_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + ), - 'maxxvit_rmlp_nano_rw_256': _cfg( + # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks) + 'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_small_rw_256': _cfg( + 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_base_rw_224': _cfg(url=''), - 'maxxvit_rmlp_large_rw_224': _cfg(url=''), + # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn) + 'maxxvitv2_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''), + + 'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), # MaxViT models ported from official Tensorflow impl 'maxvit_tiny_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_tiny_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_tiny_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_small_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_base_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_large_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in21k': _cfg( url=''), 'maxvit_base_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in21k': _cfg( url=''), 'maxvit_large_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k', + hf_hub_id='timm/', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_224.in21k': _cfg( url=''), 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), }) @@ -2027,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) +@register_model +def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs) + + @register_model def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) @@ -2148,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): @register_model -def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 9d2528f6..63c9b57f 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -496,7 +496,7 @@ class RegNet(nn.Module): self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.num_features = prev_width self.head = ClassifierHead( - in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 8ffb1200..d32f9dea 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -1029,6 +1029,10 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', hf_hub_filename='open_clip_pytorch_model.bin', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), 'vit_base_patch32_clip_224.openai': _cfg( hf_hub_id='timm/', @@ -1498,6 +1502,17 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs): return model +@register_model +def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs): + """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + # Experimental models below @register_model diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index e3348e64..6bb7085f 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -216,7 +216,7 @@ class XceptionAligned(nn.Module): num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] self.act = act_layer(inplace=True) if preact else nn.Identity() self.head = ClassifierHead( - in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/version.py b/timm/version.py index b285df69..e2ac9a76 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.6dev0' +__version__ = '0.8.8dev0'