Merge branch 'rwightman:main' into metaformer_baselines_for_vision

Fredo Guan, committed by GitHub
commit 0b1f84142f

@@ -40,9 +40,10 @@ jobs:
      - name: Install torch on ubuntu
        if: startsWith(matrix.os, 'ubuntu')
        run: |
-          pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
+          sudo sed -i 's/azure\.//' /etc/apt/sources.list
          sudo apt update
          sudo apt install -y google-perftools
+          pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
      - name: Install requirements
        run: |
          pip install -r requirements.txt

@@ -24,6 +24,65 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Jan 20, 2023
* Add two convnext 12k -> 1k fine-tunes at 384x384
  * `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
  * `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384
* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models
|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)|
|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:|
|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22|
|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76|
|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99|
|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15|
|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84|
|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90|
|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95|
|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74|
|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43|
|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64|
|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77|
|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99|
|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22|
|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15|
|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78|
|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90|
|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84|
|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77|
|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59|
|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65|
|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42|
|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35|
|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13|
|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01|
|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38|
|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78|
|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30|
|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17|
|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92|
|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60|
|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11|
|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78|
|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47|
|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05|
|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05|
|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92|
|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28|
|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04|
|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73|
|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34|
|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80|
|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41|
|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86|
### Jan 11, 2023
* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
  * `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)
  * `convnext_tiny.in12k_ft_in1k` - 84.2 @ 224, 84.5 @ 288
  * `convnext_small.in12k_ft_in1k` - 85.2 @ 224, 85.3 @ 288
### Jan 6, 2023
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
  * `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
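The `--model-kwargs` plumbing boils down to turning `key=value` tokens from the command line into a kwargs dict handed to the model constructor. A minimal sketch of that idea (`parse_kv_args` is a hypothetical helper for illustration, not timm's actual parser):

```python
import ast

def parse_kv_args(tokens):
    """Parse ['output_stride=16', 'act_layer=silu'] into a kwargs dict.

    Values are literal-eval'd when possible (ints, floats, tuples);
    bare names like 'silu' fail literal_eval and stay plain strings.
    """
    kwargs = {}
    for tok in tokens:
        key, _, value = tok.partition('=')
        try:
            kwargs[key] = ast.literal_eval(value)
        except (ValueError, SyntaxError):
            kwargs[key] = value  # keep as string, e.g. an act layer name
    return kwargs

print(parse_kv_args(['output_stride=16', 'act_layer=silu']))
# -> {'output_stride': 16, 'act_layer': 'silu'}
```

The resulting dict can then be splatted into a model factory call, which is the "pass through rare args" behavior the changelog entry describes.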

@@ -38,7 +38,7 @@ An ImageNet test set of 10,000 images sampled from new images roughly 10 years a
### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv)
-A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occuring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
+A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
For clean validation with same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv)

@@ -27,7 +27,7 @@ NON_STD_FILTERS = [
    'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
    'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
    'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
-    'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
+    'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
        '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
        '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
        '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
-        'swin*giant*', 'convnextv2_huge*']
+        'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
    NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
    EXCLUDE_FILTERS = []
@@ -53,7 +53,7 @@ MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
-MAX_FWD_FX_SIZE = 224
+MAX_FWD_FX_SIZE = 256
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224
@@ -269,7 +269,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
    '*iabn*', 'tresnet*',  # models using inplace abn unlikely to ever be scriptable
-    'dla*', 'hrnet*', 'ghostnet*',  # hopefully fix at some point
+    'dla*', 'hrnet*', 'ghostnet*'  # hopefully fix at some point
    'vit_large_*', 'vit_huge_*', 'vit_gi*',
]

@@ -1,6 +1,6 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
    rand_augment_transform, auto_augment_transform
-from .config import resolve_data_config
+from .config import resolve_data_config, resolve_model_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset

@@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)

def resolve_data_config(
-        args,
-        default_cfg=None,
+        args=None,
+        pretrained_cfg=None,
        model=None,
        use_test_size=False,
        verbose=False
):
-    new_config = {}
-    default_cfg = default_cfg or {}
-    if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
-        default_cfg = model.default_cfg
+    assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
+    args = args or {}
+    pretrained_cfg = pretrained_cfg or {}
+    if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
+        pretrained_cfg = model.pretrained_cfg
+
+    data_config = {}

    # Resolve input/image size
    in_chans = 3

@@ -32,65 +34,94 @@ def resolve_data_config(
        assert isinstance(args['img_size'], int)
        input_size = (in_chans, args['img_size'], args['img_size'])
    else:
-        if use_test_size and default_cfg.get('test_input_size', None) is not None:
-            input_size = default_cfg['test_input_size']
-        elif default_cfg.get('input_size', None) is not None:
-            input_size = default_cfg['input_size']
-    new_config['input_size'] = input_size
+        if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
+            input_size = pretrained_cfg['test_input_size']
+        elif pretrained_cfg.get('input_size', None) is not None:
+            input_size = pretrained_cfg['input_size']
+    data_config['input_size'] = input_size

    # resolve interpolation method
-    new_config['interpolation'] = 'bicubic'
+    data_config['interpolation'] = 'bicubic'
    if args.get('interpolation', None):
-        new_config['interpolation'] = args['interpolation']
-    elif default_cfg.get('interpolation', None):
-        new_config['interpolation'] = default_cfg['interpolation']
+        data_config['interpolation'] = args['interpolation']
+    elif pretrained_cfg.get('interpolation', None):
+        data_config['interpolation'] = pretrained_cfg['interpolation']

    # resolve dataset + model mean for normalization
-    new_config['mean'] = IMAGENET_DEFAULT_MEAN
+    data_config['mean'] = IMAGENET_DEFAULT_MEAN
    if args.get('mean', None) is not None:
        mean = tuple(args['mean'])
        if len(mean) == 1:
            mean = tuple(list(mean) * in_chans)
        else:
            assert len(mean) == in_chans
-        new_config['mean'] = mean
-    elif default_cfg.get('mean', None):
-        new_config['mean'] = default_cfg['mean']
+        data_config['mean'] = mean
+    elif pretrained_cfg.get('mean', None):
+        data_config['mean'] = pretrained_cfg['mean']

    # resolve dataset + model std deviation for normalization
-    new_config['std'] = IMAGENET_DEFAULT_STD
+    data_config['std'] = IMAGENET_DEFAULT_STD
    if args.get('std', None) is not None:
        std = tuple(args['std'])
        if len(std) == 1:
            std = tuple(list(std) * in_chans)
        else:
            assert len(std) == in_chans
-        new_config['std'] = std
-    elif default_cfg.get('std', None):
-        new_config['std'] = default_cfg['std']
+        data_config['std'] = std
+    elif pretrained_cfg.get('std', None):
+        data_config['std'] = pretrained_cfg['std']

    # resolve default inference crop
    crop_pct = DEFAULT_CROP_PCT
    if args.get('crop_pct', None):
        crop_pct = args['crop_pct']
    else:
-        if use_test_size and default_cfg.get('test_crop_pct', None):
-            crop_pct = default_cfg['test_crop_pct']
-        elif default_cfg.get('crop_pct', None):
-            crop_pct = default_cfg['crop_pct']
-    new_config['crop_pct'] = crop_pct
+        if use_test_size and pretrained_cfg.get('test_crop_pct', None):
+            crop_pct = pretrained_cfg['test_crop_pct']
+        elif pretrained_cfg.get('crop_pct', None):
+            crop_pct = pretrained_cfg['crop_pct']
+    data_config['crop_pct'] = crop_pct

    # resolve default crop percentage
    crop_mode = DEFAULT_CROP_MODE
    if args.get('crop_mode', None):
        crop_mode = args['crop_mode']
-    elif default_cfg.get('crop_mode', None):
-        crop_mode = default_cfg['crop_mode']
-    new_config['crop_mode'] = crop_mode
+    elif pretrained_cfg.get('crop_mode', None):
+        crop_mode = pretrained_cfg['crop_mode']
+    data_config['crop_mode'] = crop_mode

    if verbose:
        _logger.info('Data processing configuration for current model + dataset:')
-        for n, v in new_config.items():
+        for n, v in data_config.items():
            _logger.info('\t%s: %s' % (n, str(v)))

-    return new_config
+    return data_config
def resolve_model_data_config(
        model,
        args=None,
        pretrained_cfg=None,
        use_test_size=False,
        verbose=False,
):
    """ Resolve Model Data Config
    This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.

    Args:
        model (nn.Module): the model instance
        args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
        pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
        use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
        verbose (bool): enable extra logging of resolved values

    Returns:
        dictionary of config
    """
    return resolve_data_config(
        args=args,
        pretrained_cfg=pretrained_cfg,
        model=model,
        use_test_size=use_test_size,
        verbose=verbose,
    )
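The resolution order in `resolve_data_config()` is: explicit `args` win, then the model's `pretrained_cfg`, then library defaults. A stripped-down sketch of that precedence with plain dicts (a hypothetical `resolve_one` helper for illustration, not the timm implementation):

```python
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
DEFAULT_CROP_PCT = 0.875

def resolve_one(key, args, pretrained_cfg, default):
    """args overrides pretrained_cfg, which overrides the library default."""
    if args.get(key) is not None:
        return args[key]
    if pretrained_cfg.get(key) is not None:
        return pretrained_cfg[key]
    return default

# pretrained_cfg as it might ship with a model's weights
pretrained_cfg = {'input_size': (3, 224, 224), 'interpolation': 'bicubic', 'crop_pct': 0.95}
args = {'crop_pct': 1.0}  # e.g. an explicit command-line override

config = {
    'input_size': resolve_one('input_size', args, pretrained_cfg, (3, 224, 224)),
    'interpolation': resolve_one('interpolation', args, pretrained_cfg, 'bilinear'),
    'mean': resolve_one('mean', args, pretrained_cfg, IMAGENET_DEFAULT_MEAN),
    'crop_pct': resolve_one('crop_pct', args, pretrained_cfg, DEFAULT_CROP_PCT),
}
print(config['crop_pct'])  # args wins -> 1.0
print(config['mean'])      # neither set -> ImageNet default
```

The real function adds per-key special cases (`img_size` shorthand, `test_input_size`/`test_crop_pct` when `use_test_size=True`, channel-count expansion of mean/std), but the layered precedence is the core idea.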

@@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
+from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
+    SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d

@@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False

class ClassifierHead(nn.Module):
    """Classifier head w/ configurable global pooling and dropout."""

-    def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
+    def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
        super(ClassifierHead, self).__init__()
        self.drop_rate = drop_rate
-        self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+        self.in_features = in_features
+        self.use_conv = use_conv
+        self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
        self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()

    def reset(self, num_classes, global_pool=None):
        if global_pool is not None:
            if global_pool != self.global_pool.pool_type:
                self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
            self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
        num_pooled_features = self.in_features * self.global_pool.feat_mult()
        self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)

    def forward(self, x, pre_logits: bool = False):
        x = self.global_pool(x)
        if self.drop_rate:

@@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any

import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d

from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
@@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
-                self.num_batches_tracked = self.num_batches_tracked + 1  # type: ignore[has-type]
+                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
            if self.momentum is None:  # use cumulative moving average
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:  # use exponential moving average
@@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None):
    return module_output


class FrozenBatchNormAct2d(torch.nn.Module):
    """
    BatchNormAct2d where the batch statistics and the affine parameters are fixed

    Args:
        num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
        eps (float): a value added to the denominator for numerical stability. Default: 1e-5
    """

    def __init__(
            self,
            num_features: int,
            eps: float = 1e-5,
            apply_act=True,
            act_layer=nn.ReLU,
            inplace=True,
            drop_layer=None,
    ):
        super().__init__()
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def _load_from_state_dict(
            self,
            state_dict: dict,
            prefix: str,
            local_metadata: dict,
            strict: bool,
            missing_keys: List[str],
            unexpected_keys: List[str],
            error_msgs: List[str],
    ):
        num_batches_tracked_key = prefix + "num_batches_tracked"
        if num_batches_tracked_key in state_dict:
            del state_dict[num_batches_tracked_key]

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # move reshapes to the beginning
        # to make it fuser-friendly
        w = self.weight.reshape(1, -1, 1, 1)
        b = self.bias.reshape(1, -1, 1, 1)
        rv = self.running_var.reshape(1, -1, 1, 1)
        rm = self.running_mean.reshape(1, -1, 1, 1)
        scale = w * (rv + self.eps).rsqrt()
        bias = b - rm * scale
        x = x * scale + bias
        x = self.act(self.drop(x))
        return x

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"


def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
    of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
        res = FrozenBatchNormAct2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = freeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res


def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
    recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, FrozenBatchNormAct2d):
        res = BatchNormAct2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
        res.drop = module.drop
        res.act = module.act
    elif isinstance(module, FrozenBatchNorm2d):
        res = torch.nn.BatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res


def _num_groups(num_channels, num_groups, group_size):
    if group_size:
        assert num_channels % group_size == 0
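The fused `forward()` in `FrozenBatchNormAct2d` rewrites the usual `(x - mean) / sqrt(var + eps) * weight + bias` into a single `x * scale + bias` with precomputed `scale = weight * rsqrt(var + eps)`, which is cheaper and fuser-friendly since the statistics never change. A quick scalar check (plain Python, no torch) that the two forms agree:

```python
import math

def bn_reference(x, mean, var, weight, bias, eps=1e-5):
    """Textbook batch norm with affine parameters."""
    return (x - mean) / math.sqrt(var + eps) * weight + bias

def bn_fused(x, mean, var, weight, bias, eps=1e-5):
    """Frozen-BN style: fold stats and affine into one scale/shift pair."""
    scale = weight * (var + eps) ** -0.5  # weight * rsqrt(var + eps)
    shift = bias - mean * scale
    return x * scale + shift

# identical per-channel stats give identical outputs either way
for x in (-2.0, 0.0, 0.5, 3.7):
    a = bn_reference(x, mean=0.4, var=1.3, weight=2.0, bias=-0.1)
    b = bn_fused(x, mean=0.4, var=1.3, weight=2.0, bias=-0.1)
    assert math.isclose(a, b, rel_tol=1e-12), (a, b)
```

In the module itself `scale` and `shift` are per-channel tensors reshaped to `(1, C, 1, 1)` so the same algebra broadcasts over an `(N, C, H, W)` batch.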
@@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size):

class GroupNormAct(nn.GroupNorm):
    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
-            self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            num_channels,
+            num_groups=32,
+            eps=1e-5,
+            affine=True,
+            group_size=None,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
        super(GroupNormAct, self).__init__(
-            _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
+            _num_groups(num_channels, num_groups, group_size),
+            num_channels,
+            eps=eps,
+            affine=affine,
+        )
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()
        self._fast_norm = is_fast_norm()

    def forward(self, x):
        if self._fast_norm:
            x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        x = self.drop(x)
        x = self.act(x)
        return x


class GroupNorm1Act(nn.GroupNorm):
    def __init__(
            self,
            num_channels,
            eps=1e-5,
            affine=True,
            apply_act=True,
            act_layer=nn.ReLU,
            inplace=True,
            drop_layer=None,
    ):
        super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
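`GroupNormAct` (and the new `GroupNorm1Act`, which fixes `num_groups=1`) normalizes each group of channels to zero mean and unit variance before the affine scale/shift and activation. A tiny list-based sketch of just the group-statistics step (illustration only, not the `F.group_norm` kernel):

```python
def group_norm_1d(values, num_groups, eps=1e-5):
    """Normalize a flat list of per-channel values in num_groups equal groups."""
    group_size = len(values) // num_groups
    out = []
    for g in range(num_groups):
        group = values[g * group_size:(g + 1) * group_size]
        mean = sum(group) / len(group)
        var = sum((v - mean) ** 2 for v in group) / len(group)
        out.extend((v - mean) / (var + eps) ** 0.5 for v in group)
    return out

normed = group_norm_1d([1.0, 3.0, 10.0, 30.0], num_groups=2)
# each group of two now has zero mean; values land near -1 and +1
```

With `num_groups=1` the whole channel dimension is one group, which is exactly the `GroupNorm1Act` case (a LayerNorm-like normalization over channels).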
@@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm):

class LayerNormAct(nn.LayerNorm):
    def __init__(
-            self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            normalization_shape: Union[int, List[int], torch.Size],
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
@@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm):

class LayerNormAct2d(nn.LayerNorm):
    def __init__(
-            self, num_channels, eps=1e-5, affine=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
+            self,
+            num_channels,
+            eps=1e-5,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+    ):
        super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module

@@ -8,6 +8,7 @@ from .convmixer import *
from .convnext import *
from .crossvit import *
from .cspnet import *
from .davit import *
from .deit import *
from .densenet import *
from .dla import *

@@ -179,11 +179,11 @@ def load_pretrained(
        return
    if filter_fn is not None:
        try:
            state_dict = filter_fn(state_dict, model)
        except TypeError as e:
            # for backwards compat with filter fn that take one arg
            state_dict = filter_fn(state_dict)

    input_convs = pretrained_cfg.get('first_conv', None)
    if input_convs is not None and in_chans != 3:
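The hunk above inverts the old fallback order: the two-argument `filter_fn(state_dict, model)` form is now tried first, with the legacy one-argument form as the `TypeError` fallback. A minimal standalone sketch of that dispatch pattern (`apply_filter` is a hypothetical name, not part of timm):

```python
def apply_filter(filter_fn, state_dict, model):
    """Try the new two-arg filter fn first, fall back to the legacy one-arg form."""
    try:
        return filter_fn(state_dict, model)
    except TypeError:
        # backwards compat with filter fns that take only state_dict; note that a
        # TypeError raised *inside* a two-arg fn would also take this path
        return filter_fn(state_dict)

new_style = lambda sd, model: dict(sd, saw_model=model is not None)
old_style = lambda sd: sd
print(apply_filter(new_style, {'w': 1}, object()))
print(apply_filter(old_style, {'w': 1}, object()))
```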

@@ -236,20 +236,7 @@ def push_to_hf_hub(
        model_card = model_card or {}
        model_name = repo_id.split('/')[-1]
        readme_path = Path(tmpdir) / "README.md"
        readme_text = generate_readme(model_card, model_name)
        readme_path.write_text(readme_text)

        # Upload model and return
@@ -260,3 +247,51 @@ def push_to_hf_hub(
            create_pr=create_pr,
            commit_message=commit_message,
        )
def generate_readme(model_card, model_name):
readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n"
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
if 'Pretrain Dataset' in model_card['details']:
readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
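As a quick illustration of what the refactored helper produces, here is a condensed, standalone sketch of just the YAML front-matter portion (the `front_matter` name is hypothetical; the field handling mirrors the `details`/`Dataset` logic above):

```python
def front_matter(model_card, model_name):
    """Condensed sketch of the front matter generate_readme() emits."""
    text = "---\n"
    text += "tags:\n- image-classification\n- timm\n"
    text += "library_tag: timm\n"
    text += f"license: {model_card.get('license', 'apache-2.0')}\n"
    details = model_card.get('details', {})
    if 'Dataset' in details:
        text += 'datasets:\n'
        text += f"- {details['Dataset'].lower()}\n"
        if 'Pretrain Dataset' in details:
            text += f"- {details['Pretrain Dataset'].lower()}\n"
    text += "---\n"
    return text + f"# Model card for {model_name}\n"

print(front_matter({'details': {'Dataset': 'ImageNet-1k'}}, 'convnext_tiny.in12k_ft_in1k'))
```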

@@ -43,7 +43,7 @@ from functools import partial
import torch
import torch.nn as nn

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
    LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._builder import build_model_with_cfg
@@ -205,6 +205,7 @@ class ConvNeXt(nn.Module):
            use_grn=False,
            act_layer='gelu',
            norm_layer=None,
            norm_eps=None,
            drop_rate=0.,
            drop_path_rate=0.,
    ):
@@ -236,10 +237,15 @@ class ConvNeXt(nn.Module):
        if norm_layer is None:
            norm_layer = LayerNorm2d
            norm_layer_cl = norm_layer if conv_mlp else LayerNorm
            if norm_eps is not None:
                norm_layer = partial(norm_layer, eps=norm_eps)
                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
        else:
            assert conv_mlp,\
                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
            norm_layer_cl = norm_layer
            if norm_eps is not None:
                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)

        self.num_classes = num_classes
        self.drop_rate = drop_rate
@@ -250,7 +256,7 @@ class ConvNeXt(nn.Module):
            # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
            self.stem = nn.Sequential(
                nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
                norm_layer(dims[0]),
            )
            stem_stride = patch_size
        else:
@@ -301,11 +307,10 @@ class ConvNeXt(nn.Module):
        # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
        # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
        self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
        self.head = nn.Sequential(OrderedDict([
            ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
            ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
            ('drop', nn.Dropout(self.drop_rate)),
            ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@@ -336,14 +341,7 @@ class ConvNeXt(nn.Module):
        if global_pool is not None:
            self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
            self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.stem(x)
@@ -384,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
        return state_dict  # non-FB checkpoint
    if 'model' in state_dict:
        state_dict = state_dict['model']

    out_dict = {}
    if 'visual.trunk.stem.0.weight' in state_dict:
        out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
        if 'visual.head.proj.weight' in state_dict:
            out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
            out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
        return out_dict

    import re
    for k, v in state_dict.items():
        k = k.replace('downsample_layers.0.', 'stem.')
@@ -403,10 +409,16 @@ def checkpoint_filter_fn(state_dict, model):
            model_shape = model.state_dict()[k].shape
            v = v.reshape(model_shape)
        out_dict[k] = v

    return out_dict


def _create_convnext(variant, pretrained=False, **kwargs):
    if kwargs.get('pretrained_cfg', '') == 'fcmae':
        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
        # This is workaround loading with num_classes=0 w/o removing norm-layer.
        kwargs.setdefault('pretrained_strict', False)

    model = build_model_with_cfg(
        ConvNeXt, variant, pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
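The new CLIP branch in `checkpoint_filter_fn` strips the OpenCLIP `visual.trunk.` prefix and relocates the image-tower projection into `head.fc`, zero-filling the bias since the CLIP head has none. A pure-Python sketch of that remap (`remap_clip` is a hypothetical name; plain lists stand in for tensors):

```python
def remap_clip(state_dict):
    """Sketch of the OpenCLIP -> timm ConvNeXt key remap shown above."""
    out = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items()
           if k.startswith('visual.trunk.')}
    if 'visual.head.proj.weight' in state_dict:
        proj = state_dict['visual.head.proj.weight']
        out['head.fc.weight'] = proj
        out['head.fc.bias'] = [0.0] * len(proj)  # torch.zeros(proj.shape[0]) in the real code
    return out

sd = {'visual.trunk.stem.0.weight': 'w', 'visual.head.proj.weight': [[1.0], [2.0]]}
print(remap_clip(sd))
```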
@@ -481,10 +493,29 @@ default_cfgs = generate_default_cfgs({
        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
        hf_hub_id='timm/',
        crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
    'convnext_nano.in12k': _cfg(
        hf_hub_id='timm/',
        crop_pct=0.95, num_classes=11821),
'convnext_tiny.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_small.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
    'convnext_tiny.fb_in1k': _cfg(
        url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
@@ -676,6 +707,33 @@ default_cfgs = generate_default_cfgs({
        num_classes=0),
    'convnextv2_small.untrained': _cfg(),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_augreg_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
})

@@ -913,7 +913,7 @@ class CspNet(nn.Module):
        # Construct the head
        self.num_features = prev_chs
        self.head = ClassifierHead(
            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)

        named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@@ -0,0 +1,679 @@
""" DaViT: Dual Attention Vision Transformers
As described in https://arxiv.org/abs/2204.03645
Input size invariant transformer architecture that combines channel and spatial
attention in each block. The attention mechanisms used are linear in complexity.
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['DaViT']
class ConvPosEnc(nn.Module):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
self.act = nn.GELU() if act else nn.Identity()
def forward(self, x: Tensor):
feat = self.proj(x)
x = x + self.act(feat)
return x
class Stem(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
def __init__(
self,
in_chs=3,
out_chs=96,
stride=4,
norm_layer=LayerNorm2d,
):
super().__init__()
stride = to_2tuple(stride)
self.stride = stride
self.in_chs = in_chs
self.out_chs = out_chs
assert stride[0] == 4 # only setup for stride==4
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=7,
stride=stride,
padding=3,
)
self.norm = norm_layer(out_chs)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
x = self.conv(x)
x = self.norm(x)
return x
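The `Stem.forward` above pads on the right and bottom so H and W become divisible by the stride before the stride-4 conv. A quick sketch of that arithmetic (`padded_hw` is a hypothetical helper name):

```python
def padded_hw(h, w, stride=(4, 4)):
    """Pad H/W up to the next multiple of the stride, as the DaViT Stem does."""
    pad_b = (stride[0] - h % stride[0]) % stride[0]
    pad_r = (stride[1] - w % stride[1]) % stride[1]
    return h + pad_b, w + pad_r

assert padded_hw(225, 223) == (228, 224)  # odd sizes rounded up
assert padded_hw(224, 224) == (224, 224)  # exact multiples left untouched
```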
class Downsample(nn.Module):
def __init__(
self,
in_chs,
out_chs,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.norm = norm_layer(in_chs)
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=2,
stride=2,
padding=0,
)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
x = F.pad(x, (0, (2 - W % 2) % 2))
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
x = self.conv(x)
return x
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = attention.softmax(dim=-1)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
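`ChannelAttention` attends over the per-head channel dimension rather than over tokens, so the attention matrix is `(C/heads, C/heads)` per head and the cost grows only linearly with token count N. A NumPy sketch of the shape flow, assuming `q = k = v = x` (the learned `qkv`/`proj` linear layers are omitted):

```python
import numpy as np

def channel_attention(x, num_heads=2):
    """NumPy sketch of DaViT channel attention; projections omitted for clarity."""
    N, C = x.shape
    d = C // num_heads
    q = k = v = x.reshape(N, num_heads, d).transpose(1, 0, 2)   # (h, N, d)
    k = k * d ** -0.5
    attn = k.transpose(0, 2, 1) @ v                              # (h, d, d): channel x channel
    attn = np.exp(attn - attn.max(-1, keepdims=True))
    attn = attn / attn.sum(-1, keepdims=True)                    # softmax(dim=-1)
    out = (attn @ q.transpose(0, 2, 1)).transpose(0, 2, 1)       # (h, N, d)
    return out.transpose(1, 0, 2).reshape(N, C)

y = channel_attention(np.ones((10, 8)))
assert y.shape == (10, 8)
```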
class ChannelBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path2 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.cpe1(x).flatten(2).transpose(1, 2)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path1(cur)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
def window_partition(x: Tensor, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
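`window_partition` and `window_reverse` are exact inverses when H and W divide evenly by the window size (the padding in `SpatialBlock` guarantees this). A NumPy mirror of the two functions demonstrating the round trip (names prefixed `np_` to mark them as sketches, not the torch versions above):

```python
import numpy as np

def np_window_partition(x, ws):
    """(B, H, W, C) -> (num_windows*B, ws[0], ws[1], C), NumPy mirror of the above."""
    B, H, W, C = x.shape
    x = x.reshape(B, H // ws[0], ws[0], W // ws[1], ws[1], C)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(-1, ws[0], ws[1], C)

def np_window_reverse(win, ws, H, W):
    """Inverse of np_window_partition."""
    B = win.shape[0] // ((H // ws[0]) * (W // ws[1]))
    x = win.reshape(B, H // ws[0], W // ws[1], ws[0], ws[1], -1)
    return x.transpose(0, 1, 3, 2, 4, 5).reshape(B, H, W, -1)

x = np.arange(2 * 8 * 8 * 3, dtype=float).reshape(2, 8, 8, 3)
w = np_window_partition(x, (4, 4))
assert w.shape == (8, 4, 4, 3)   # 4 windows per image x 2 images
assert np.array_equal(np_window_reverse(w, (4, 4), 8, 8), x)
```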
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: Tensor):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SpatialBlock(nn.Module):
r""" Windows Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.dim = dim
self.ffn = ffn
self.num_heads = num_heads
self.window_size = to_2tuple(window_size)
self.mlp_ratio = mlp_ratio
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
            self.drop_path2 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
# if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path1(x)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth=1,
downsample=True,
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
cpe_act=False
):
super().__init__()
self.grad_checkpointing = False
# downsample embedding layer at the beginning of each stage
if downsample:
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
potential opportunity to integrate with a more general version of ByobNet/ByoaNet
since the logic is similar
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
elif attn_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x: Tensor):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class DaViT(nn.Module):
r""" DaViT
A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
Supports arbitrary input sizes and pyramid feature extraction
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(
self,
in_chans=3,
depths=(1, 1, 3, 1),
embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4,
qkv_bias=True,
norm_layer='layernorm2d',
norm_layer_cl='layernorm',
norm_eps=1e-5,
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
global_pool='avg',
head_norm_first=False,
):
super().__init__()
num_stages = len(embed_dims)
assert num_stages == len(num_heads) == len(depths)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
in_chs = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
stage = DaViTStage(
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
)
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.fc.weight' in state_dict:
return state_dict # non-MSFT checkpoint
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
k = k.replace('head.', 'head.fc.')
k = k.replace('norms.', 'head.norm.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
return out_dict
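The substitution order in the remap above matters: `patch_embeds.0` first becomes `stages.0.downsample`, and only then is that prefix rewritten to `stem`. A standalone sketch of the key remap (`remap_key` is a hypothetical name extracting the per-key logic):

```python
import re

def remap_key(k):
    """Sketch of the MSFT -> timm DaViT key remap; order of substitutions matters."""
    k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
    k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
    k = k.replace('downsample.proj', 'downsample.conv')
    k = k.replace('stages.0.downsample', 'stem')
    k = k.replace('head.', 'head.fc.')
    k = k.replace('norms.', 'head.norm.')
    k = k.replace('cpe.0', 'cpe1')
    k = k.replace('cpe.1', 'cpe2')
    return k

assert remap_key('patch_embeds.0.proj.weight') == 'stem.conv.weight'
assert remap_key('head.weight') == 'head.fc.weight'
```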
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
        'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
        'crop_pct': 0.95, 'interpolation': 'bicubic',
        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
        'first_conv': 'stem.conv', 'classifier': 'head.fc',
        **kwargs
    }


# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
    # official microsoft weights from https://github.com/dingmyu/davit
    'davit_tiny.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_small.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_base.msft_in1k': _cfg(
        hf_hub_id='timm/'),
    'davit_large': _cfg(),
    'davit_huge': _cfg(),
    'davit_giant': _cfg(),
})


@register_model
def davit_tiny(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)


@register_model
def davit_small(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
    return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)


@register_model
def davit_base(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
    return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)


@register_model
def davit_large(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
    return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)


@register_model
def davit_huge(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
    return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)


@register_model
def davit_giant(pretrained=False, **kwargs):
    model_kwargs = dict(
        depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
    return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)

@@ -12,7 +12,7 @@ import torch.utils.checkpoint as cp
 from torch.jit.annotations import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
+from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import MATCH_PREV_GROUP
 from ._registry import register_model
@@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict):
     _version = 2

     def __init__(
-            self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
-            drop_rate=0., memory_efficient=False):
+            self,
+            num_layers,
+            num_input_features,
+            bn_size,
+            growth_rate,
+            norm_layer=BatchNormAct2d,
+            drop_rate=0.,
+            memory_efficient=False,
+    ):
         super(DenseBlock, self).__init__()
         for i in range(num_layers):
             layer = DenseLayer(
@@ -165,12 +172,25 @@ class DenseNet(nn.Module):
     """

     def __init__(
-            self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
-            bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
-            memory_efficient=False, aa_stem_only=True):
+            self,
+            growth_rate=32,
+            block_config=(6, 12, 24, 16),
+            num_classes=1000,
+            in_chans=3,
+            global_pool='avg',
+            bn_size=4,
+            stem_type='',
+            act_layer='relu',
+            norm_layer='batchnorm2d',
+            aa_layer=None,
+            drop_rate=0,
+            memory_efficient=False,
+            aa_stem_only=True,
+    ):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         super(DenseNet, self).__init__()
+        norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)

         # Stem
         deep_stem = 'deep' in stem_type  # 3x3 deep stem
@@ -226,8 +246,11 @@ class DenseNet(nn.Module):
                 dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
             current_stride *= 2
             trans = DenseTransition(
-                num_input_features=num_features, num_output_features=num_features // 2,
-                norm_layer=norm_layer, aa_layer=transition_aa_layer)
+                num_input_features=num_features,
+                num_output_features=num_features // 2,
+                norm_layer=norm_layer,
+                aa_layer=transition_aa_layer,
+            )
             self.features.add_module(f'transition{i + 1}', trans)
             num_features = num_features // 2
@@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs):
     `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
     """
     model = _create_densenet(
-        'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
-        aa_layer=BlurPool2d, **kwargs)
+        'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
+        stem_type='deep', aa_layer=BlurPool2d, **kwargs)
     return model
@@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs):
 def densenet264d_iabn(pretrained=False, **kwargs):
     r"""Densenet-264 model with deep stem and Inplace-ABN
     """
-    def norm_act_fn(num_features, **kwargs):
-        return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
     model = _create_densenet(
-        'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
-        norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
+        'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
+        norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
     return model
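The change above replaces layer-class arguments (`norm_layer=BatchNormAct2d`, a local `norm_act_fn`) with string specs (`'batchnorm2d'`, `'iabn'`) resolved by `get_norm_act_layer`. A toy stdlib-only sketch of that string-to-factory lookup (registry and factory names here are hypothetical, not timm's real implementation):

```python
# hypothetical mini-registry mapping norm+act spec strings to layer factories
_NORM_ACT_MAP = {}

def register_norm_act(name, factory):
    _NORM_ACT_MAP[name] = factory

def get_norm_act(spec):
    # strings resolve through the registry; callables pass through unchanged
    if isinstance(spec, str):
        return _NORM_ACT_MAP[spec.lower()]
    return spec

# stand-in factories returning tuples instead of nn.Module instances
register_norm_act('batchnorm2d', lambda num_features, act='relu': ('bn', num_features, act))
register_norm_act('iabn', lambda num_features, act='leaky_relu': ('iabn', num_features, act))

layer = get_norm_act('iabn')(64)
```

The benefit is that model configs become serializable strings instead of class references, which is why `norm_layer='iabn', act_layer='leaky_relu'` can replace the old closure.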

@@ -15,7 +15,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
+from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
 from ._builder import build_model_with_cfg
 from ._registry import register_model
@@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
 default_cfgs = {
+    'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'dpn68': _cfg(
         url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
     'dpn68b': _cfg(
@@ -82,7 +83,16 @@ class BnActConv2d(nn.Module):
 class DualPathBlock(nn.Module):
     def __init__(
-            self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
+            self,
+            in_chs,
+            num_1x1_a,
+            num_3x3_b,
+            num_1x1_c,
+            inc,
+            groups,
+            block_type='normal',
+            b=False,
+    ):
         super(DualPathBlock, self).__init__()
         self.num_1x1_c = num_1x1_c
         self.inc = inc
@@ -167,16 +177,31 @@ class DualPathBlock(nn.Module):
 class DPN(nn.Module):
     def __init__(
-            self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
-            b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
-            num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
+            self,
+            k_sec=(3, 4, 20, 3),
+            inc_sec=(16, 32, 24, 128),
+            k_r=96,
+            groups=32,
+            num_classes=1000,
+            in_chans=3,
+            output_stride=32,
+            global_pool='avg',
+            small=False,
+            num_init_features=64,
+            b=False,
+            drop_rate=0.,
+            norm_layer='batchnorm2d',
+            act_layer='relu',
+            fc_act_layer='elu',
+    ):
         super(DPN, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.b = b
         assert output_stride == 32  # FIXME look into dilation support
-        norm_layer = partial(BatchNormAct2d, eps=.001)
-        fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
+        norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
+        fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
         bw_factor = 1 if small else 4
         blocks = OrderedDict()
@@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs):
         **kwargs)

+@register_model
+def dpn48b(pretrained=False, **kwargs):
+    model_kwargs = dict(
+        small=True, num_init_features=10, k_r=128, groups=32,
+        b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
+    return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn68(pretrained=False, **kwargs):
     model_kwargs = dict(
         small=True, num_init_features=10, k_r=128, groups=32,
-        k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
-    return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
+        k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
+    return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn68b(pretrained=False, **kwargs):
     model_kwargs = dict(
         small=True, num_init_features=10, k_r=128, groups=32,
-        b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
-    return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
+        b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
+    return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn92(pretrained=False, **kwargs):
     model_kwargs = dict(
         num_init_features=64, k_r=96, groups=32,
-        k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
-    return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
+        k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
+    return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn98(pretrained=False, **kwargs):
     model_kwargs = dict(
         num_init_features=96, k_r=160, groups=40,
-        k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
-    return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
+        k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
+    return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn131(pretrained=False, **kwargs):
     model_kwargs = dict(
         num_init_features=128, k_r=160, groups=40,
-        k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
-    return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
+        k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
+    return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
 def dpn107(pretrained=False, **kwargs):
     model_kwargs = dict(
         num_init_features=128, k_r=200, groups=50,
-        k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
-    return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
+        k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
+    return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))
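The entrypoint rewrite above swaps `dict(..., **kwargs)` inside the defaults for `dict(model_kwargs, **kwargs)` at the call site. The distinction matters: expanding `**kwargs` in the same call that names a duplicate keyword raises `TypeError`, while passing the defaults as a positional mapping lets caller kwargs override them cleanly. A small self-contained demonstration:

```python
# variant defaults for a hypothetical model, and a caller override
defaults = dict(small=True, num_init_features=10, k_r=128, groups=32)
overrides = dict(k_r=160)

# new style: positional mapping first, keyword expansion second -> override wins
merged = dict(defaults, **overrides)

# old style: a duplicate key in the same call collides instead of overriding
try:
    dict(k_r=128, **overrides)
    collided = False
except TypeError:
    collided = True
```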

@@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
 The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
 match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
-# FIXME / WARNING
-This impl remains a WIP, some configs and models may vanish or change...
 Papers:
 MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
     partition_ratio: int = 32
     window_size: Optional[Tuple[int, int]] = None
     grid_size: Optional[Tuple[int, int]] = None
+    no_block_attn: bool = False  # disable window block attention for maxvit (ie only grid)
+    use_nchw_attn: bool = False  # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
     init_values: Optional[float] = None
     act_layer: str = 'gelu'
     norm_layer: str = 'layernorm2d'
@@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
             stride: int = 1,
             conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
             transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
-            use_nchw_attn: bool = False,  # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
-            use_block_attn: bool = True,  # FIXME for testing ConvNeXt conv w/o block attention
             drop_path: float = 0.,
     ):
         super().__init__()
+        self.nchw_attn = transformer_cfg.use_nchw_attn
         conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
         self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
         attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
-        partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
-        self.nchw_attn = use_nchw_attn
-        self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
+        partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
+        self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
         self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)

     def init_weights(self, scheme=''):
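The `no_block_attn` flag above drops MaxViT's local window ("block") attention and keeps only the strided grid attention. The two partitionings can be sketched with plain index arithmetic (a toy illustration of the MaxViT idea, not timm's tensor implementation):

```python
def block_windows(h, w, p):
    # window / "block" attention: each group is a contiguous p x p tile (local mixing)
    wins = []
    for bi in range(h // p):
        for bj in range(w // p):
            wins.append([(bi * p + i, bj * p + j) for i in range(p) for j in range(p)])
    return wins

def grid_windows(h, w, p):
    # grid attention: each group takes one element from every p x p tile,
    # i.e. a strided / dilated sampling that mixes globally
    wins = []
    for i in range(p):
        for j in range(p):
            wins.append([(i + gi * p, j + gj * p) for gi in range(h // p) for gj in range(w // p)])
    return wins

bw = block_windows(4, 4, 2)
gw = grid_windows(4, 4, 2)
```

With block attention disabled, each MaxxVitBlock still gets global mixing from the grid branch, which is the design point of the MaxxViT-V2 configs further down.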
@@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
             hidden_size=None,
             pool_type='avg',
             drop_rate=0.,
-            norm_layer=nn.LayerNorm,
-            act_layer=nn.Tanh,
+            norm_layer='layernorm2d',
+            act_layer='tanh',
     ):
         super().__init__()
         self.drop_rate = drop_rate
+        self.in_features = in_features
+        self.hidden_size = hidden_size
         self.num_features = in_features
+        self.use_conv = not pool_type
+        norm_layer = get_norm_layer(norm_layer)
+        act_layer = get_act_layer(act_layer)
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear

         self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
         self.norm = norm_layer(in_features)
         self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
         if hidden_size:
             self.pre_logits = nn.Sequential(OrderedDict([
-                ('fc', nn.Linear(in_features, hidden_size)),
+                ('fc', linear_layer(in_features, hidden_size)),
                 ('act', act_layer()),
             ]))
             self.num_features = hidden_size
         else:
             self.pre_logits = nn.Identity()
         self.drop = nn.Dropout(self.drop_rate)
-        self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes, global_pool=None):
+        if global_pool is not None:
+            self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+            self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
+            self.use_conv = self.global_pool.is_identity()
+        linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
+        if self.hidden_size:
+            if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
+                    (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
+                with torch.no_grad():
+                    new_fc = linear_layer(self.in_features, self.hidden_size)
+                    new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
+                    new_fc.bias.copy_(self.pre_logits.fc.bias)
+                    self.pre_logits.fc = new_fc
+        self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

     def forward(self, x, pre_logits: bool = False):
         x = self.global_pool(x)
@@ -1116,6 +1135,26 @@ class NormMlpHead(nn.Module):
         return x

+def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
+    transformer_kwargs = {}
+    conv_kwargs = {}
+    base_kwargs = {}
+    for k, v in kwargs.items():
+        if k.startswith('transformer_'):
+            transformer_kwargs[k.replace('transformer_', '')] = v
+        elif k.startswith('conv_'):
+            conv_kwargs[k.replace('conv_', '')] = v
+        else:
+            base_kwargs[k] = v
+    cfg = replace(
+        cfg,
+        transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
+        conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
+        **base_kwargs,
+    )
+    return cfg

 class MaxxVit(nn.Module):
     """ CoaTNet + MaxVit base model.
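`_overlay_kwargs` routes prefixed kwargs onto nested dataclass configs via `dataclasses.replace`. A self-contained toy re-creation of the same prefix-routing idea (`SubCfg`/`Cfg`/`overlay_kwargs` are illustrative names, not the timm classes):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class SubCfg:
    depth: int = 2
    width: int = 64

@dataclass(frozen=True)
class Cfg:
    name: str = 'demo'
    transformer_cfg: SubCfg = SubCfg()
    conv_cfg: SubCfg = SubCfg()

def overlay_kwargs(cfg, **kwargs):
    # split kwargs by prefix, strip the prefix, and overlay on the matching sub-config
    t_kw, c_kw, base = {}, {}, {}
    for k, v in kwargs.items():
        if k.startswith('transformer_'):
            t_kw[k.removeprefix('transformer_')] = v
        elif k.startswith('conv_'):
            c_kw[k.removeprefix('conv_')] = v
        else:
            base[k] = v
    return replace(
        cfg,
        transformer_cfg=replace(cfg.transformer_cfg, **t_kw),
        conv_cfg=replace(cfg.conv_cfg, **c_kw),
        **base,
    )

cfg = overlay_kwargs(Cfg(), transformer_depth=4, conv_width=96, name='custom')
```

This is what lets `MaxxVit(..., **kwargs)` accept flat overrides like `transformer_init_values=...` without callers constructing nested config objects.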
@@ -1130,16 +1169,20 @@ class MaxxVit(nn.Module):
             num_classes: int = 1000,
             global_pool: str = 'avg',
             drop_rate: float = 0.,
-            drop_path_rate: float = 0.
+            drop_path_rate: float = 0.,
+            **kwargs,
     ):
         super().__init__()
         img_size = to_2tuple(img_size)
+        if kwargs:
+            cfg = _overlay_kwargs(cfg, **kwargs)
         transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.embed_dim = cfg.embed_dim[-1]
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
+        self.feature_info = []

         self.stem = Stem(
             in_chs=in_chans,
@@ -1150,8 +1193,8 @@ class MaxxVit(nn.Module):
             norm_layer=cfg.conv_cfg.norm_layer,
             norm_eps=cfg.conv_cfg.norm_eps,
         )
         stride = self.stem.stride
+        self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
         feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])

         num_stages = len(cfg.embed_dim)
@@ -1175,15 +1218,17 @@ class MaxxVit(nn.Module):
             )]
             stride *= stage_stride
             in_chs = out_chs
+            self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)

         final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
-        if cfg.head_hidden_size:
+        self.head_hidden_size = cfg.head_hidden_size
+        if self.head_hidden_size:
             self.norm = nn.Identity()
             self.head = NormMlpHead(
                 self.num_features,
                 num_classes,
-                hidden_size=cfg.head_hidden_size,
+                hidden_size=self.head_hidden_size,
                 pool_type=global_pool,
                 drop_rate=drop_rate,
                 norm_layer=final_norm_layer,
@@ -1230,9 +1275,7 @@ class MaxxVit(nn.Module):
     def reset_classifier(self, num_classes, global_pool=None):
         self.num_classes = num_classes
-        if global_pool is None:
-            global_pool = self.head.global_pool.pool_type
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)
@@ -1353,6 +1396,7 @@ def _next_cfg(
         transformer_norm_layer='layernorm2d',
         transformer_norm_layer_cl='layernorm',
         window_size=None,
+        no_block_attn=False,
         init_values=1e-6,
         rel_pos_type='mlp',  # MLP by default for maxxvit
         rel_pos_dim=512,
@@ -1373,6 +1417,7 @@ def _next_cfg(
             expand_first=False,
             pool_type=pool_type,
             window_size=window_size,
+            no_block_attn=no_block_attn,  # enabled for MaxxViT-V2
             init_values=init_values[1],
             norm_layer=transformer_norm_layer,
             norm_layer_cl=transformer_norm_layer_cl,
@@ -1399,8 +1444,8 @@ def _tf_cfg():
 model_cfgs = dict(
-    # Fiddling with configs / defaults / still pretraining
-    coatnet_pico_rw_224=MaxxVitCfg(
+    # timm specific CoAtNet configs
+    coatnet_pico_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 3, 5, 2),
         stem_width=(32, 64),
@@ -1409,7 +1454,7 @@ model_cfgs = dict(
             conv_attn_ratio=0.25,
         ),
     ),
-    coatnet_nano_rw_224=MaxxVitCfg(
+    coatnet_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1419,7 +1464,7 @@ model_cfgs = dict(
             conv_attn_ratio=0.25,
         ),
     ),
-    coatnet_0_rw_224=MaxxVitCfg(
+    coatnet_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1428,7 +1473,7 @@ model_cfgs = dict(
             transformer_shortcut_bias=False,
         ),
     ),
-    coatnet_1_rw_224=MaxxVitCfg(
+    coatnet_1_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1438,7 +1483,7 @@ model_cfgs = dict(
             transformer_shortcut_bias=False,
         )
     ),
-    coatnet_2_rw_224=MaxxVitCfg(
+    coatnet_2_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
@@ -1448,7 +1493,7 @@ model_cfgs = dict(
             #init_values=1e-6,
         ),
     ),
-    coatnet_3_rw_224=MaxxVitCfg(
+    coatnet_3_rw=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=(96, 192),
@@ -1459,8 +1504,8 @@ model_cfgs = dict(
         ),
     ),
-    # Highly experimental configs
-    coatnet_bn_0_rw_224=MaxxVitCfg(
+    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
+    coatnet_bn_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1471,7 +1516,7 @@ model_cfgs = dict(
             transformer_norm_layer='batchnorm2d',
         )
     ),
-    coatnet_rmlp_nano_rw_224=MaxxVitCfg(
+    coatnet_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1482,7 +1527,7 @@ model_cfgs = dict(
             rel_pos_dim=384,
         ),
     ),
-    coatnet_rmlp_0_rw_224=MaxxVitCfg(
+    coatnet_rmlp_0_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 7, 2),  # deeper than paper '0' model
         stem_width=(32, 64),
@@ -1491,7 +1536,7 @@ model_cfgs = dict(
             rel_pos_type='mlp',
         ),
     ),
-    coatnet_rmlp_1_rw_224=MaxxVitCfg(
+    coatnet_rmlp_1_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1503,7 +1548,7 @@ model_cfgs = dict(
             rel_pos_dim=384,  # was supposed to be 512, woops
         ),
     ),
-    coatnet_rmlp_1_rw2_224=MaxxVitCfg(
+    coatnet_rmlp_1_rw2=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=(32, 64),
@@ -1513,7 +1558,7 @@ model_cfgs = dict(
             rel_pos_dim=512,  # was supposed to be 512, woops
         ),
     ),
-    coatnet_rmlp_2_rw_224=MaxxVitCfg(
+    coatnet_rmlp_2_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=(64, 128),
@@ -1524,7 +1569,7 @@ model_cfgs = dict(
             rel_pos_type='mlp'
         ),
     ),
-    coatnet_rmlp_3_rw_224=MaxxVitCfg(
+    coatnet_rmlp_3_rw=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=(96, 192),
@@ -1536,14 +1581,14 @@ model_cfgs = dict(
         ),
     ),
-    coatnet_nano_cc_224=MaxxVitCfg(
+    coatnet_nano_cc=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
         block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
         **_rw_coat_cfg(),
     ),
-    coatnext_nano_rw_224=MaxxVitCfg(
+    coatnext_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(3, 4, 6, 3),
         stem_width=(32, 64),
@@ -1555,89 +1600,95 @@ model_cfgs = dict(
     ),

     # Trying to be like the CoAtNet paper configs
-    coatnet_0_224=MaxxVitCfg(
+    coatnet_0=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 3, 5, 2),
         stem_width=64,
+        head_hidden_size=768,
     ),
-    coatnet_1_224=MaxxVitCfg(
+    coatnet_1=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 6, 14, 2),
         stem_width=64,
+        head_hidden_size=768,
     ),
-    coatnet_2_224=MaxxVitCfg(
+    coatnet_2=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 14, 2),
         stem_width=128,
+        head_hidden_size=1024,
     ),
-    coatnet_3_224=MaxxVitCfg(
+    coatnet_3=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 6, 14, 2),
         stem_width=192,
+        head_hidden_size=1536,
     ),
-    coatnet_4_224=MaxxVitCfg(
+    coatnet_4=MaxxVitCfg(
         embed_dim=(192, 384, 768, 1536),
         depths=(2, 12, 28, 2),
         stem_width=192,
+        head_hidden_size=1536,
     ),
-    coatnet_5_224=MaxxVitCfg(
+    coatnet_5=MaxxVitCfg(
         embed_dim=(256, 512, 1280, 2048),
         depths=(2, 12, 28, 2),
         stem_width=192,
+        head_hidden_size=2048,
     ),

     # Experimental MaxVit configs
-    maxvit_pico_rw_256=MaxxVitCfg(
+    maxvit_pico_rw=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
         **_rw_max_cfg(),
     ),
-    maxvit_nano_rw_256=MaxxVitCfg(
+    maxvit_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_tiny_rw_224=MaxxVitCfg(
+    maxvit_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_tiny_rw_256=MaxxVitCfg(
+    maxvit_tiny_pm=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
-        block_type=('M',) * 4,
+        block_type=('PM',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(),
     ),
-    maxvit_rmlp_pico_rw_256=MaxxVitCfg(
+    maxvit_rmlp_pico_rw=MaxxVitCfg(
         embed_dim=(32, 64, 128, 256),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(24, 32),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_nano_rw_256=MaxxVitCfg(
+    maxvit_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
+    maxvit_rmlp_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_rw_max_cfg(rel_pos_type='mlp'),
     ),
-    maxvit_rmlp_small_rw_224=MaxxVitCfg(
+    maxvit_rmlp_small_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
@@ -1647,26 +1698,18 @@ model_cfgs = dict(
             init_values=1e-6,
         ),
     ),
-    maxvit_rmlp_small_rw_256=MaxxVitCfg(
+    maxvit_rmlp_base_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
-        depths=(2, 2, 5, 2),
+        depths=(2, 6, 14, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
+        head_hidden_size=768,
         **_rw_max_cfg(
             rel_pos_type='mlp',
-            init_values=1e-6,
         ),
     ),
-    maxvit_tiny_pm_256=MaxxVitCfg(
-        embed_dim=(64, 128, 256, 512),
-        depths=(2, 2, 5, 2),
-        block_type=('PM',) * 4,
-        stem_width=(32, 64),
-        **_rw_max_cfg(),
-    ),
-    maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_nano_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
@@ -1674,33 +1717,50 @@ model_cfgs = dict(
         weight_init='normal',
         **_next_cfg(),
     ),
-    maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_tiny_rw=MaxxVitCfg(
         embed_dim=(64, 128, 256, 512),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(32, 64),
         **_next_cfg(),
     ),
-    maxxvit_rmlp_small_rw_256=MaxxVitCfg(
+    maxxvit_rmlp_small_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
         depths=(2, 2, 5, 2),
         block_type=('M',) * 4,
         stem_width=(48, 96),
         **_next_cfg(),
     ),
-    maxxvit_rmlp_base_rw_224=MaxxVitCfg(
+
+    maxxvitv2_nano_rw=MaxxVitCfg(
         embed_dim=(96, 192, 384, 768),
-        depths=(2, 6, 14, 2),
+        depths=(1, 2, 3, 1),
         block_type=('M',) * 4,
         stem_width=(48, 96),
-        **_next_cfg(),
+        weight_init='normal',
+        **_next_cfg(
+            no_block_attn=True,
+            rel_pos_type='bias',
+        ),
     ),
-    maxxvit_rmlp_large_rw_224=MaxxVitCfg(
+    maxxvitv2_rmlp_base_rw=MaxxVitCfg(
         embed_dim=(128, 256, 512, 1024),
         depths=(2, 6, 12, 2),
         block_type=('M',) * 4,
         stem_width=(64, 128),
-        **_next_cfg(),
+        **_next_cfg(
+            no_block_attn=True,
+        ),
+    ),
+    maxxvitv2_rmlp_large_rw=MaxxVitCfg(
+        embed_dim=(160, 320, 640, 1280),
+        depths=(2, 6, 16, 2),
+        block_type=('M',) * 4,
+        stem_width=(80, 160),
+        head_hidden_size=1280,
+        **_next_cfg(
+            no_block_attn=True,
+        ),
     ),

     # Trying to be like the MaxViT paper configs
@ -1752,11 +1812,29 @@ model_cfgs = dict(
) )
+def checkpoint_filter_fn(state_dict, model: nn.Module):
+    model_state_dict = model.state_dict()
+    out_dict = {}
+    for k, v in state_dict.items():
+        if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
+            # adapt between conv2d / linear layers
+            assert v.ndim in (2, 4)
+            v = v.reshape(model_state_dict[k].shape)
+        out_dict[k] = v
+    return out_dict
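To make the intent of the new `checkpoint_filter_fn` concrete: when a checkpoint tensor and the target parameter agree in element count but differ in rank (e.g. a 1x1 conv2d kernel saved as 4-D vs. a linear weight expected as 2-D), it reshapes instead of failing the load. A minimal standalone sketch of that logic, using NumPy arrays as stand-ins for torch tensors (the `filter_checkpoint` name and data here are illustrative, not timm API):

```python
import numpy as np

def filter_checkpoint(state_dict, model_state_dict):
    # Mirror of the diff's checkpoint_filter_fn: when a checkpoint tensor and
    # the model parameter hold the same number of elements but differ in rank,
    # reshape the checkpoint tensor to the parameter's shape instead of failing.
    out = {}
    for k, v in state_dict.items():
        target = model_state_dict.get(k)
        if target is not None and v.ndim != target.ndim and v.size == target.size:
            assert v.ndim in (2, 4)  # only conv2d <-> linear adaptation is expected
            v = v.reshape(target.shape)
        out[k] = v
    return out

# A checkpoint saved with a 1x1 conv head, loaded into a model using a linear head.
ckpt = {'head.weight': np.zeros((10, 8, 1, 1))}
model_sd = {'head.weight': np.zeros((10, 8))}
fixed = filter_checkpoint(ckpt, model_sd)
print(fixed['head.weight'].shape)  # (10, 8)
```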
 def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
+    if cfg_variant is None:
+        if variant in model_cfgs:
+            cfg_variant = variant
+        else:
+            cfg_variant = '_'.join(variant.split('_')[:-1])
     return build_model_with_cfg(
         MaxxVit, variant, pretrained,
-        model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
+        model_cfg=model_cfgs[cfg_variant],
         feature_cfg=dict(flatten_sequential=True),
+        pretrained_filter_fn=checkpoint_filter_fn,
         **kwargs)
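The `cfg_variant` fallback added above lets several registered variants (differing only in input resolution) share one config: if the full variant name is not a config key, the trailing resolution token (e.g. `_256`) is stripped. A standalone sketch of just that resolution step (the `cfgs` entries are illustrative):

```python
def resolve_cfg_variant(variant, model_cfgs, cfg_variant=None):
    # Mirrors the fallback added to _create_maxxvit: prefer an exact config key,
    # otherwise drop the last '_'-separated token (the resolution suffix).
    if cfg_variant is None:
        if variant in model_cfgs:
            cfg_variant = variant
        else:
            cfg_variant = '_'.join(variant.split('_')[:-1])
    return cfg_variant

cfgs = {'maxxvitv2_nano_rw': {}, 'maxvit_tiny_pm_256': {}}
print(resolve_cfg_variant('maxxvitv2_nano_rw_256', cfgs))  # maxxvitv2_nano_rw
print(resolve_cfg_variant('maxvit_tiny_pm_256', cfgs))     # maxvit_tiny_pm_256
```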
@@ -1772,149 +1850,218 @@ def _cfg(url='', **kwargs):
 default_cfgs = generate_default_cfgs({
-    # Fiddling with configs / defaults / still pretraining
-    'coatnet_pico_rw_224': _cfg(url=''),
-    'coatnet_nano_rw_224': _cfg(
+    # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
+    'coatnet_pico_rw_224.untrained': _cfg(url=''),
+    'coatnet_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
         crop_pct=0.9),
-    'coatnet_0_rw_224': _cfg(
+    'coatnet_0_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
-    'coatnet_1_rw_224': _cfg(
+    'coatnet_1_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
     ),
-    'coatnet_2_rw_224': _cfg(url=''),
-    'coatnet_3_rw_224': _cfg(url=''),
-    # Highly experimental configs
-    'coatnet_bn_0_rw_224': _cfg(
+
+    # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
+    'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    #'coatnet_3_rw_224.untrained': _cfg(url=''),
+
+    # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
+    'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
+    'coatnet_bn_0_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
         crop_pct=0.95),
-    'coatnet_rmlp_nano_rw_224': _cfg(
+    'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
         crop_pct=0.9),
-    'coatnet_rmlp_0_rw_224': _cfg(url=''),
-    'coatnet_rmlp_1_rw_224': _cfg(
+    'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
+    'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
-    'coatnet_rmlp_1_rw2_224': _cfg(url=''),
-    'coatnet_rmlp_2_rw_224': _cfg(
+    'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
-    'coatnet_rmlp_3_rw_224': _cfg(url=''),
-    'coatnet_nano_cc_224': _cfg(url=''),
-    'coatnext_nano_rw_224': _cfg(
+    'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
+    'coatnet_nano_cc_224.untrained': _cfg(url=''),
+    'coatnext_nano_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
         crop_pct=0.9),
-    # Trying to be like the CoAtNet paper configs
-    'coatnet_0_224': _cfg(url=''),
-    'coatnet_1_224': _cfg(url=''),
-    'coatnet_2_224': _cfg(url=''),
-    'coatnet_3_224': _cfg(url=''),
-    'coatnet_4_224': _cfg(url=''),
-    'coatnet_5_224': _cfg(url=''),
-
-    # Experimental configs
-    'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_nano_rw_256': _cfg(
+
+    # ImageNet-12k pretrain CoAtNet
+    'coatnet_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_3_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+    'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),
+
+    # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
+    'coatnet_0_224.untrained': _cfg(url=''),
+    'coatnet_1_224.untrained': _cfg(url=''),
+    'coatnet_2_224.untrained': _cfg(url=''),
+    'coatnet_3_224.untrained': _cfg(url=''),
+    'coatnet_4_224.untrained': _cfg(url=''),
+    'coatnet_5_224.untrained': _cfg(url=''),
+
+    # timm specific MaxVit configs, ImageNet-1k pretrain or untrained
+    'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxvit_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_rw_224': _cfg(
+    'maxvit_tiny_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
-    'maxvit_tiny_rw_256': _cfg(
+    'maxvit_tiny_rw_256.untrained': _cfg(
         url='',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_pico_rw_256': _cfg(
+    'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+
+    # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
+    'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_nano_rw_256': _cfg(
+    'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_tiny_rw_256': _cfg(
+    'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_rmlp_small_rw_224': _cfg(
+    'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
         crop_pct=0.9,
     ),
-    'maxvit_rmlp_small_rw_256': _cfg(
+    'maxvit_rmlp_small_rw_256.untrained': _cfg(
         url='',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_nano_rw_256': _cfg(
+
+    # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
+    'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+    ),
+    'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+
+    # timm specific MaxVit w/ ImageNet-12k pretrain
+    'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821,
+    ),
+
+    # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
+    'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_small_rw_256': _cfg(
+    'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
         input_size=(3, 256, 256), pool_size=(8, 8)),
-    'maxxvit_rmlp_base_rw_224': _cfg(url=''),
-    'maxxvit_rmlp_large_rw_224': _cfg(url=''),
+
+    # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
+    'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 256, 256), pool_size=(8, 8)),
+    'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/'),
+    'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
+    'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
+    'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
+        hf_hub_id='timm/',
+        num_classes=11821),

     # MaxViT models ported from official Tensorflow impl
     'maxvit_tiny_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_tiny_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_tiny_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_small_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_small_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_small_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_base_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_224.in1k',
+        hf_hub_id='timm/',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
     'maxvit_large_tf_384.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_384.in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_512.in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),

     'maxvit_base_tf_224.in21k': _cfg(
         url=''),
     'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_224.in21k': _cfg(
         url=''),
     'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
+        hf_hub_id='timm/',
         input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_224.in21k': _cfg(
         url=''),
     'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
-        input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
     'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
-        hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
-        input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
+        hf_hub_id='timm/',
+        input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
 })
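The renamed keys above follow timm's `model_name.tag` convention, where the suffix after the first `.` is a pretrained-weights tag (`sw_in1k`, `sw_in12k_ft_in1k`, `untrained`, ...), letting one architecture carry several weight sets. The real handling lives in timm's pretrained-config utilities; the `split_model_tag` helper below is an illustrative stand-in for how such a key separates:

```python
def split_model_tag(name):
    # 'model_name.tag' -> ('model_name', 'tag'); the tag is empty when absent.
    model, _, tag = name.partition('.')
    return model, tag

print(split_model_tag('coatnet_nano_rw_224.sw_in1k'))   # ('coatnet_nano_rw_224', 'sw_in1k')
print(split_model_tag('maxvit_pico_rw_256.untrained'))  # ('maxvit_pico_rw_256', 'untrained')
```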
@@ -1978,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)

+@register_model
+def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
+    return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
+
 @register_model
 def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
     return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@@ -2068,6 +2220,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)

+@register_model
+def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
+
+@register_model
+def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
+    return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
+
 @register_model
 def maxvit_tiny_pm_256(pretrained=False, **kwargs):
     return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@@ -2089,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
 @register_model
-def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
+def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
+
+@register_model
+def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
+
+@register_model
+def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)

 @register_model
-def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
-    return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
+def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
+    return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)

 @register_model

@@ -266,9 +266,16 @@ class MobileVitBlock(nn.Module):
         self.transformer = nn.Sequential(*[
             TransformerBlock(
-                transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
-                attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
-                act_layer=layers.act, norm_layer=transformer_norm_layer)
+                transformer_dim,
+                mlp_ratio=mlp_ratio,
+                num_heads=num_heads,
+                qkv_bias=True,
+                attn_drop=attn_drop,
+                drop=drop,
+                drop_path=drop_path_rate,
+                act_layer=layers.act,
+                norm_layer=transformer_norm_layer,
+            )
             for _ in range(transformer_depth)
         ])
         self.norm = transformer_norm_layer(transformer_dim)

@@ -496,7 +496,7 @@ class RegNet(nn.Module):
         self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
         self.num_features = prev_width
         self.head = ClassifierHead(
-            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)

         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@@ -156,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
-    return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
+    return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -167,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
-    return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
+    return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -178,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
-    return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
+    return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -189,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
-    return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
+    return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -200,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
-    return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
+    return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -211,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
-    return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
+    return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))

 @register_model
@@ -222,5 +222,5 @@ def res2next50(pretrained=False, **kwargs):
     pretrained (bool): If True, returns a model pre-trained on ImageNet
     """
     model_args = dict(
-        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
-    return _create_res2net('res2next50', pretrained, **model_args)
+        block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
+    return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))
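The recurring change in these entry points moves the caller's `**kwargs` out of the `model_args` literal and merges it at call time with `**dict(model_args, **kwargs)`. The practical difference: a caller-supplied kwarg now overrides the default, whereas the old `dict(..., base_width=26, **kwargs)` form would raise a "got multiple values for keyword argument" `TypeError` if the caller also passed `base_width`. A small sketch of the pattern (the `make_model_args` helper and its defaults are illustrative):

```python
def make_model_args(**kwargs):
    # Defaults live in model_args; dict(model_args, **kwargs) builds a new dict
    # where caller-supplied keys take precedence over the defaults.
    model_args = dict(layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
    return dict(model_args, **kwargs)

args = make_model_args(base_width=48)
print(args['base_width'])  # 48 (caller override wins)
print(args['layers'])      # [3, 4, 6, 3] (default kept)
```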

@@ -163,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[1, 1, 1, 1],
         stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -174,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[2, 2, 2, 2],
         stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -186,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 4, 6, 3],
         stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -198,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 4, 23, 3],
         stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -210,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 24, 36, 3],
         stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -222,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 30, 48, 8],
         stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
-        block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
-    return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=2, avd=True, avd_first=False))
+    return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -233,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 4, 6, 3],
         stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
-        block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
-    return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=4, avd=True, avd_first=True))
+    return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

 @register_model
@@ -244,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs):
     model_kwargs = dict(
         block=ResNestBottleneck, layers=[3, 4, 6, 3],
         stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
-        block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
-    return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
+        block_args=dict(radix=1, avd=True, avd_first=True))
+    return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

@@ -704,7 +704,7 @@ class ResNet(nn.Module):
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.grad_checkpointing = False

         act_layer = get_act_layer(act_layer)
         norm_layer = get_norm_layer(norm_layer)
@@ -845,77 +845,72 @@ def _create_resnet(variant, pretrained=False, **kwargs):
 def resnet10t(pretrained=False, **kwargs):
     """Constructs a ResNet-10-T model.
     """
-    model_args = dict(
-        block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
-    return _create_resnet('resnet10t', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
+    return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet14t(pretrained=False, **kwargs):
     """Constructs a ResNet-14-T model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
-    return _create_resnet('resnet14t', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
+    return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet18(pretrained=False, **kwargs):
     """Constructs a ResNet-18 model.
     """
-    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
-    return _create_resnet('resnet18', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
+    return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet18d(pretrained=False, **kwargs):
     """Constructs a ResNet-18-D model.
     """
-    model_args = dict(
-        block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet18d', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet34(pretrained=False, **kwargs):
     """Constructs a ResNet-34 model.
     """
-    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('resnet34', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
+    return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet34d(pretrained=False, **kwargs):
     """Constructs a ResNet-34-D model.
     """
-    model_args = dict(
-        block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet34d', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet26(pretrained=False, **kwargs):
     """Constructs a ResNet-26 model.
     """
-    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
-    return _create_resnet('resnet26', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
+    return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet26t(pretrained=False, **kwargs):
     """Constructs a ResNet-26-T model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
-    return _create_resnet('resnet26t', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True)
+    return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet26d(pretrained=False, **kwargs):
     """Constructs a ResNet-26-D model.
     """
-    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet26d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -923,83 +918,79 @@ def resnet50(pretrained=False, **kwargs):
     """Constructs a ResNet-50 model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('resnet50', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
+    return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet50d(pretrained=False, **kwargs) -> ResNet:
     """Constructs a ResNet-50-D model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet50d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet50t(pretrained=False, **kwargs):
     """Constructs a ResNet-50-T model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
-    return _create_resnet('resnet50t', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True)
+    return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet101(pretrained=False, **kwargs):
     """Constructs a ResNet-101 model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
-    return _create_resnet('resnet101', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
+    return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet101d(pretrained=False, **kwargs):
     """Constructs a ResNet-101-D model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet101d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet152(pretrained=False, **kwargs):
     """Constructs a ResNet-152 model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
-    return _create_resnet('resnet152', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
+    return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet152d(pretrained=False, **kwargs):
     """Constructs a ResNet-152-D model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet152d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet200(pretrained=False, **kwargs):
     """Constructs a ResNet-200 model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs)
-    return _create_resnet('resnet200', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3])
+    return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnet200d(pretrained=False, **kwargs):
     """Constructs a ResNet-200-D model.
     """
-    model_args = dict(
-        block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnet200d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def tv_resnet34(pretrained=False, **kwargs):
     """Constructs a ResNet-34 model with original Torchvision weights.
     """
-    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('tv_resnet34', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
+    return _create_resnet('tv_resnet34', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1007,23 +998,23 @@ def tv_resnet50(pretrained=False, **kwargs):
     """Constructs a ResNet-50 model with original Torchvision weights.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('tv_resnet50', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
+    return _create_resnet('tv_resnet50', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def tv_resnet101(pretrained=False, **kwargs):
     """Constructs a ResNet-101 model w/ Torchvision pretrained weights.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
-    return _create_resnet('tv_resnet101', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
+    return _create_resnet('tv_resnet101', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def tv_resnet152(pretrained=False, **kwargs):
     """Constructs a ResNet-152 model w/ Torchvision pretrained weights.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
-    return _create_resnet('tv_resnet152', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
+    return _create_resnet('tv_resnet152', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1034,8 +1025,8 @@ def wide_resnet50_2(pretrained=False, **kwargs):
     convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
     channels, and in Wide ResNet-50-2 has 2048-1024-2048.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
-    return _create_resnet('wide_resnet50_2', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
+    return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1045,8 +1036,8 @@ def wide_resnet101_2(pretrained=False, **kwargs):
     which is twice larger in every block. The number of channels in outer 1x1
     convolutions is the same.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
-    return _create_resnet('wide_resnet101_2', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128)
+    return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1061,8 +1052,8 @@ def resnet50_gn(pretrained=False, **kwargs):
 def resnext50_32x4d(pretrained=False, **kwargs):
     """Constructs a ResNeXt50-32x4d model.
    """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('resnext50_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
+    return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1071,40 +1062,40 @@ def resnext50d_32x4d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
-        stem_width=32, stem_type='deep', avg_down=True, **kwargs)
-    return _create_resnet('resnext50d_32x4d', pretrained, **model_args)
+        stem_width=32, stem_type='deep', avg_down=True)
+    return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnext101_32x4d(pretrained=False, **kwargs):
     """Constructs a ResNeXt-101 32x4d model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('resnext101_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
+    return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnext101_32x8d(pretrained=False, **kwargs):
     """Constructs a ResNeXt-101 32x8d model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
-    return _create_resnet('resnext101_32x8d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
+    return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def resnext101_64x4d(pretrained=False, **kwargs):
     """Constructs a ResNeXt101-64x4d model.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
-    return _create_resnet('resnext101_64x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4)
+    return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
 def tv_resnext50_32x4d(pretrained=False, **kwargs):
     """Constructs a ResNeXt50-32x4d model with original Torchvision weights.
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
+    return _create_resnet('tv_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1114,8 +1105,8 @@ def ig_resnext101_32x8d(pretrained=False, **kwargs):
     `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
     Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
-    return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
+    return _create_resnet('ig_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1125,8 +1116,8 @@ def ig_resnext101_32x16d(pretrained=False, **kwargs):
     `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
     Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
-    return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
+    return _create_resnet('ig_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1136,8 +1127,8 @@ def ig_resnext101_32x32d(pretrained=False, **kwargs):
     `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
     Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
-    return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32)
+    return _create_resnet('ig_resnext101_32x32d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1147,8 +1138,8 @@ def ig_resnext101_32x48d(pretrained=False, **kwargs):
     `"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
     Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
-    return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48)
+    return _create_resnet('ig_resnext101_32x48d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1157,8 +1148,8 @@ def ssl_resnet18(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
-    return _create_resnet('ssl_resnet18', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
+    return _create_resnet('ssl_resnet18', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1168,7 +1159,7 @@ def ssl_resnet50(pretrained=False, **kwargs):
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('ssl_resnet50', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
+    return _create_resnet('ssl_resnet50', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1177,8 +1168,8 @@ def ssl_resnext50_32x4d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
+    return _create_resnet('ssl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1187,8 +1178,8 @@ def ssl_resnext101_32x4d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
+    return _create_resnet('ssl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1197,8 +1188,8 @@ def ssl_resnext101_32x8d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
-    return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
+    return _create_resnet('ssl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1207,8 +1198,8 @@ def ssl_resnext101_32x16d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
-    return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
+    return _create_resnet('ssl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1218,8 +1209,8 @@ def swsl_resnet18(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
-    return _create_resnet('swsl_resnet18', pretrained, **model_args)
+    model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
+    return _create_resnet('swsl_resnet18', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1230,7 +1221,7 @@ def swsl_resnet50(pretrained=False, **kwargs):
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
-    return _create_resnet('swsl_resnet50', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3])
+    return _create_resnet('swsl_resnet50', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1240,8 +1231,8 @@ def swsl_resnext50_32x4d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
+    return _create_resnet('swsl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1251,8 +1242,8 @@ def swsl_resnext101_32x4d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
-    return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
+    return _create_resnet('swsl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1262,8 +1253,8 @@ def swsl_resnext101_32x8d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
-    return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
+    return _create_resnet('swsl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1273,8 +1264,8 @@ def swsl_resnext101_32x16d(pretrained=False, **kwargs):
     `"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
     Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
     """
-    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
-    return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
+    model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
+    return _create_resnet('swsl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1285,8 +1276,8 @@ def ecaresnet26t(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
-        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet26t', pretrained, **model_args)
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1295,8 +1286,8 @@ def ecaresnet50d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet50d', pretrained, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1306,8 +1297,8 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1317,8 +1308,8 @@ def ecaresnet50t(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
-        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet50t', pretrained, **model_args)
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1327,8 +1318,8 @@ def ecaresnetlight(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnetlight', pretrained, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1337,8 +1328,8 @@ def ecaresnet101d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet101d', pretrained, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1348,8 +1339,8 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1358,8 +1349,8 @@ def ecaresnet200d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet200d', pretrained, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1368,8 +1359,8 @@ def ecaresnet269d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
-        block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnet269d', pretrained, **model_args)
+        block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))
 
 
 @register_model
@@ -1380,8 +1371,8 @@ def ecaresnext26t_32x4d(pretrained=False, **kwargs):
     """
     model_args = dict(
         block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
-        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
-    return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args)
+        stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
+    return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1392,54 +1383,54 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet18(pretrained=False, **kwargs): def seresnet18(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet18', pretrained, **model_args) return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet34(pretrained=False, **kwargs): def seresnet34(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet34', pretrained, **model_args) return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet50(pretrained=False, **kwargs): def seresnet50(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50', pretrained, **model_args) return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet50t(pretrained=False, **kwargs): def seresnet50t(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
block_args=dict(attn_layer='se'), **kwargs) avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50t', pretrained, **model_args) return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet101(pretrained=False, **kwargs): def seresnet101(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet101', pretrained, **model_args) return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet152(pretrained=False, **kwargs): def seresnet152(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152', pretrained, **model_args) return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnet152d(pretrained=False, **kwargs): def seresnet152d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
block_args=dict(attn_layer='se'), **kwargs) avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152d', pretrained, **model_args) return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1447,9 +1438,9 @@ def seresnet200d(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model with SE attn. """Constructs a ResNet-200-D model with SE attn.
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
block_args=dict(attn_layer='se'), **kwargs) avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet200d', pretrained, **model_args) return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1457,9 +1448,9 @@ def seresnet269d(pretrained=False, **kwargs):
"""Constructs a ResNet-269-D model with SE attn. """Constructs a ResNet-269-D model with SE attn.
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep',
block_args=dict(attn_layer='se'), **kwargs) avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet269d', pretrained, **model_args) return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1470,8 +1461,8 @@ def seresnext26d_32x4d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1482,8 +1473,8 @@ def seresnext26t_32x4d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1499,24 +1490,24 @@ def seresnext26tn_32x4d(pretrained=False, **kwargs):
def seresnext50_32x4d(pretrained=False, **kwargs): def seresnext50_32x4d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'), **kwargs) block_args=dict(attn_layer='se'))
return _create_resnet('seresnext50_32x4d', pretrained, **model_args) return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnext101_32x4d(pretrained=False, **kwargs): def seresnext101_32x4d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'), **kwargs) block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x4d', pretrained, **model_args) return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def seresnext101_32x8d(pretrained=False, **kwargs): def seresnext101_32x8d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block_args=dict(attn_layer='se'), **kwargs) block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x8d', pretrained, **model_args) return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1524,32 +1515,32 @@ def seresnext101d_32x8d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs) block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101d_32x8d', pretrained, **model_args) return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def senet154(pretrained=False, **kwargs): def senet154(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
return _create_resnet('senet154', pretrained, **model_args) return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnetblur18(pretrained=False, **kwargs): def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing """Constructs a ResNet-18 model with blur anti-aliasing
""" """
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d)
return _create_resnet('resnetblur18', pretrained, **model_args) return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnetblur50(pretrained=False, **kwargs): def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing """Constructs a ResNet-50 model with blur anti-aliasing
""" """
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
return _create_resnet('resnetblur50', pretrained, **model_args) return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1558,8 +1549,8 @@ def resnetblur50d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs) stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur50d', pretrained, **model_args) return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1568,16 +1559,25 @@ def resnetblur101d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs) stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur101d', pretrained, **model_args) return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnetaa34d(pretrained=False, **kwargs):
"""Constructs a ResNet-34-D model w/ avgpool anti-aliasing
"""
model_args = dict(
block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
def resnetaa50(pretrained=False, **kwargs): def resnetaa50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with avgpool anti-aliasing """Constructs a ResNet-50 model with avgpool anti-aliasing
""" """
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs) model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
return _create_resnet('resnetaa50', pretrained, **model_args) return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1586,8 +1586,8 @@ def resnetaa50d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs) stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa50d', pretrained, **model_args) return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1596,8 +1596,8 @@ def resnetaa101d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs) stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa101d', pretrained, **model_args) return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1606,8 +1606,8 @@ def seresnetaa50d(pretrained=False, **kwargs):
""" """
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnetaa50d', pretrained, **model_args) return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1617,8 +1617,8 @@ def seresnextaa101d_32x8d(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'), **kwargs) block_args=dict(attn_layer='se'))
return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args) return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1630,8 +1630,8 @@ def resnetrs50(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs50', pretrained, **model_args) return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1643,8 +1643,8 @@ def resnetrs101(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs101', pretrained, **model_args) return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1656,8 +1656,8 @@ def resnetrs152(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs152', pretrained, **model_args) return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1669,8 +1669,8 @@ def resnetrs200(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs200', pretrained, **model_args) return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1682,8 +1682,8 @@ def resnetrs270(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs270', pretrained, **model_args) return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
@ -1696,8 +1696,8 @@ def resnetrs350(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs350', pretrained, **model_args) return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
@register_model @register_model
@ -1709,5 +1709,5 @@ def resnetrs420(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25) attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict( model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs420', pretrained, **model_args) return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))
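The single refactor repeated through all of the resnet hunks above is worth spelling out: the model defaults move out of a `model_args = dict(..., **kwargs)` call and the caller's kwargs are instead merged at the call site with `dict(model_args, **kwargs)`, so user-supplied values take precedence over the defaults. A minimal sketch of the pattern (the `_create_resnet` stand-in here is illustrative, not the real timm factory):

```python
def _create_resnet(variant, pretrained=False, **kwargs):
    # Stand-in for timm's factory: just return the resolved config so the
    # merge behaviour is visible.
    return {'variant': variant, 'pretrained': pretrained, **kwargs}

def resnetrs420(pretrained=False, **kwargs):
    model_args = dict(stem_width=32, stem_type='deep', avg_down=True)
    # dict(model_args, **kwargs): caller kwargs override defaults that share a
    # key.  The old form, dict(stem_width=32, ..., **kwargs), would instead
    # raise TypeError if a caller also passed stem_width.
    return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))

print(resnetrs420(stem_width=64)['stem_width'])  # 64 (override wins)
print(resnetrs420()['stem_width'])               # 32 (default kept)
```

This is why the change is not purely cosmetic: with the old pattern, overriding any default from `create_model(..., stem_width=...)` was a `TypeError` ("got multiple values for keyword argument"), while the new pattern makes every default overridable.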

@@ -746,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
 @register_model
 def resnetv2_50(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50', pretrained=pretrained,
-        layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+    model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_50d(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50d', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_50t(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50t', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
-        stem_type='tiered', avg_down=True, **kwargs)
+        stem_type='tiered', avg_down=True)
+    return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_101(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101', pretrained=pretrained,
-        layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+    model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_101d(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_101d', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_152(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152', pretrained=pretrained,
-        layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
+    model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
+    return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_152d(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_152d', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))


 # Experimental configs (may change / be removed)

 @register_model
 def resnetv2_50d_gn(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50d_gn', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_50d_evob(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50d_evob', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
-        stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
+        stem_type='deep', avg_down=True, zero_init_last=True)
+    return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_50d_evos(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50d_evos', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))


 @register_model
 def resnetv2_50d_frn(pretrained=False, **kwargs):
-    return _create_resnetv2(
-        'resnetv2_50d_frn', pretrained=pretrained,
+    model_args = dict(
         layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
-        stem_type='deep', avg_down=True, **kwargs)
+        stem_type='deep', avg_down=True)
+    return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))

@@ -1029,6 +1029,10 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
         hf_hub_filename='open_clip_pytorch_model.bin',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
+    'vit_gigantic_patch14_clip_224.laion2b': _cfg(
+        hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),

     'vit_base_patch32_clip_224.openai': _cfg(
         hf_hub_id='timm/',
@@ -1498,6 +1502,17 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
     return model

+
+@register_model
+def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs):
+    """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
+    Pretrained weights from CLIP image tower.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
+    return model
+

 # Experimental models below

 @register_model
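An aside on the new `vit_gigantic_patch14_clip_224` args above: the odd-looking `mlp_ratio=64/13` pairs with `embed_dim=1664` (which is 13 × 128), so the transformer MLP hidden width comes out to a round power of two, matching the open_clip ViT-bigG image tower. A quick integer check (plain arithmetic, no timm dependency):

```python
embed_dim = 1664                    # from the vit_gigantic_patch14_clip_224 args above
hidden_dim = embed_dim * 64 // 13   # exact, since 1664 = 13 * 128
print(hidden_dim)  # 8192
```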

@@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
             num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
         self.act = act_layer(inplace=True) if preact else nn.Identity()
         self.head = ClassifierHead(
-            in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)

     @torch.jit.ignore
     def group_matcher(self, coarse=False):

@ -7,6 +7,8 @@ import fnmatch
import torch import torch
from torchvision.ops.misc import FrozenBatchNorm2d from torchvision.ops.misc import FrozenBatchNorm2d
from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .model_ema import ModelEma from .model_ema import ModelEma
@ -100,70 +102,6 @@ def extract_spp_stats(
return hook.stats return hook.stats
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
    recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module

    Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
    """
    res = module
    if isinstance(module, FrozenBatchNorm2d):
        res = torch.nn.BatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
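One detail worth noting in both conversions: `weight` and `bias` are cloned and detached (cutting them out of any autograd graph), while `running_mean` and `running_var` — buffers that never require grad — are handed over by plain assignment, so old and new layer share the same storage. A torch-free illustration of that difference, with plain lists standing in for tensors:

```python
# Illustrative only: lists stand in for tensors.
weight = [1.0, 2.0]
running_mean = [0.0, 0.0]

frozen_weight = list(weight)        # analogue of clone().detach(): independent copy
frozen_running_mean = running_mean  # plain assignment: same object, aliased

weight[0] = 9.9
print(frozen_weight[0])        # 1.0 -- the clone is unaffected
running_mean[0] = 5.0
print(frozen_running_mean[0])  # 5.0 -- the alias tracks the original
```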
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
    """
    Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
@@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
    """
    assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
    if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
    if isinstance(root_module, (
            torch.nn.modules.batchnorm.BatchNorm2d,
            torch.nn.modules.batchnorm.SyncBatchNorm,
            BatchNormAct2d,
            SyncBatchNormAct,
    )):
        # Raise assertion here because we can't convert it in place
        raise AssertionError(
            "You have provided a batch norm layer as the `root module`. Please use "
@@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
            # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
            # convert it in place, but will return the converted result. In this case `res` holds the converted
            # result and we may try to re-assign the named module
            if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
            if isinstance(m, (
                    torch.nn.modules.batchnorm.BatchNorm2d,
                    torch.nn.modules.batchnorm.SyncBatchNorm,
                    BatchNormAct2d,
                    SyncBatchNormAct,
            )):
                _add_submodule(root_module, n, res)
        # Unfreeze batch norm
        else:
            res = unfreeze_batch_norm_2d(m)
            # Ditto. See note above in mode == 'freeze' branch
            if isinstance(m, FrozenBatchNorm2d):
            if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
                _add_submodule(root_module, n, res)
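Besides the batch-norm conversion handled in the hunks above, the other half of `_freeze_unfreeze` is toggling `requires_grad` on the parameters of the selected submodules. A hypothetical, torch-free sketch of that selection logic — `Param`, `Mod`, and `set_requires_grad` are illustrative names, not timm APIs:

```python
class Param:
    """Stand-in for a torch parameter: just carries the requires_grad flag."""
    def __init__(self):
        self.requires_grad = True


class Mod:
    """Stand-in for a module tree with one parameter per node."""
    def __init__(self, **children):
        self.params = [Param()]
        self.children = dict(children)

    def all_parameters(self):
        # Yield own parameters, then descend into children.
        yield from self.params
        for child in self.children.values():
            yield from child.all_parameters()


def set_requires_grad(root, submodule_names=(), mode='freeze'):
    # Mirrors the selection logic: operate on the named submodules if any
    # were given, otherwise on the whole root module.
    assert mode in ('freeze', 'unfreeze')
    targets = [root.children[n] for n in submodule_names] if submodule_names else [root]
    for t in targets:
        for p in t.all_parameters():
            p.requires_grad = (mode == 'unfreeze')


model = Mod(stem=Mod(), head=Mod())
set_requires_grad(model, ['stem'], mode='freeze')
```

After the call, every parameter under `stem` has `requires_grad=False` while `head` stays trainable — which is the typical fine-tuning use case this helper serves.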
@@ -1 +1 @@
__version__ = '0.8.5dev0'
__version__ = '0.8.8dev0'