Merge branch 'rwightman:main' into metaformer_baselines_for_vision

pull/1647/head
Fredo Guan 1 year ago committed by GitHub
commit 0b1f84142f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -40,9 +40,10 @@ jobs:
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: |
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo sed -i 's/azure\.//' /etc/apt/sources.list
sudo apt update
sudo apt install -y google-perftools
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install requirements
run: |
pip install -r requirements.txt

@ -24,6 +24,65 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Jan 20, 2023
* Add two convnext 12k -> 1k fine-tunes at 384x384
* `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
* `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384
* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models
|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)|
|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:|
|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22|
|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76|
|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99|
|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15|
|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84|
|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90|
|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95|
|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74|
|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43|
|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64|
|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77|
|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99|
|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22|
|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15|
|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78|
|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90|
|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84|
|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77|
|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59|
|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65|
|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42|
|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35|
|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13|
|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01|
|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38|
|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78|
|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30|
|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17|
|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92|
|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60|
|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11|
|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78|
|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47|
|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05|
|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05|
|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92|
|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28|
|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04|
|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73|
|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34|
|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80|
|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41|
|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86|
### Jan 11, 2023
* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
* `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)
* `convnext_tiny.in12k_ft_in1k` - 84.2 @ 224, 84.5 @ 288
* `convnext_small.in12k_ft_in1k` - 85.2 @ 224, 85.3 @ 288
### Jan 6, 2023
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`

@ -38,7 +38,7 @@ An ImageNet test set of 10,000 images sampled from new images roughly 10 years a
### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv)
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occuring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
For clean validation with same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv)

@ -27,7 +27,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'convnextv2_huge*']
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
@ -53,7 +53,7 @@ MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
MAX_FWD_FX_SIZE = 224
MAX_FWD_FX_SIZE = 256
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224
@ -269,7 +269,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point
'vit_large_*', 'vit_huge_*', 'vit_gi*',
]

@ -1,6 +1,6 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .config import resolve_data_config, resolve_model_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset

@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
def resolve_data_config(
args,
default_cfg=None,
args=None,
pretrained_cfg=None,
model=None,
use_test_size=False,
verbose=False
):
new_config = {}
default_cfg = default_cfg or {}
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
args = args or {}
pretrained_cfg = pretrained_cfg or {}
if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
pretrained_cfg = model.pretrained_cfg
data_config = {}
# Resolve input/image size
in_chans = 3
@ -32,65 +34,94 @@ def resolve_data_config(
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
else:
if use_test_size and default_cfg.get('test_input_size', None) is not None:
input_size = default_cfg['test_input_size']
elif default_cfg.get('input_size', None) is not None:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
input_size = pretrained_cfg['test_input_size']
elif pretrained_cfg.get('input_size', None) is not None:
input_size = pretrained_cfg['input_size']
data_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bicubic'
data_config['interpolation'] = 'bicubic'
if args.get('interpolation', None):
new_config['interpolation'] = args['interpolation']
elif default_cfg.get('interpolation', None):
new_config['interpolation'] = default_cfg['interpolation']
data_config['interpolation'] = args['interpolation']
elif pretrained_cfg.get('interpolation', None):
data_config['interpolation'] = pretrained_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN
data_config['mean'] = IMAGENET_DEFAULT_MEAN
if args.get('mean', None) is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif default_cfg.get('mean', None):
new_config['mean'] = default_cfg['mean']
data_config['mean'] = mean
elif pretrained_cfg.get('mean', None):
data_config['mean'] = pretrained_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD
data_config['std'] = IMAGENET_DEFAULT_STD
if args.get('std', None) is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif default_cfg.get('std', None):
new_config['std'] = default_cfg['std']
data_config['std'] = std
elif pretrained_cfg.get('std', None):
data_config['std'] = pretrained_cfg['std']
# resolve default inference crop
crop_pct = DEFAULT_CROP_PCT
if args.get('crop_pct', None):
crop_pct = args['crop_pct']
else:
if use_test_size and default_cfg.get('test_crop_pct', None):
crop_pct = default_cfg['test_crop_pct']
elif default_cfg.get('crop_pct', None):
crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct
if use_test_size and pretrained_cfg.get('test_crop_pct', None):
crop_pct = pretrained_cfg['test_crop_pct']
elif pretrained_cfg.get('crop_pct', None):
crop_pct = pretrained_cfg['crop_pct']
data_config['crop_pct'] = crop_pct
# resolve default crop percentage
crop_mode = DEFAULT_CROP_MODE
if args.get('crop_mode', None):
crop_mode = args['crop_mode']
elif default_cfg.get('crop_mode', None):
crop_mode = default_cfg['crop_mode']
new_config['crop_mode'] = crop_mode
elif pretrained_cfg.get('crop_mode', None):
crop_mode = pretrained_cfg['crop_mode']
data_config['crop_mode'] = crop_mode
if verbose:
_logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
for n, v in data_config.items():
_logger.info('\t%s: %s' % (n, str(v)))
return new_config
return data_config
def resolve_model_data_config(
model,
args=None,
pretrained_cfg=None,
use_test_size=False,
verbose=False,
):
""" Resolve Model Data Config
This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
Args:
model (nn.Module): the model instance
args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
verbose (bool): enable extra logging of resolved values
Returns:
dictionary of config
"""
return resolve_data_config(
args=args,
pretrained_cfg=pretrained_cfg,
model=model,
use_test_size=use_test_size,
verbose=verbose,
)

@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d

@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
self.in_features = in_features
self.use_conv = use_conv
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:

@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None):
return module_output
class FrozenBatchNormAct2d(torch.nn.Module):
"""
BatchNormAct2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super().__init__()
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
x = x * scale + bias
x = self.act(self.drop(x))
return x
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
res = FrozenBatchNormAct2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNormAct2d):
res = BatchNormAct2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size):
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(
self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
num_groups=32,
eps=1e-5,
affine=True,
group_size=None,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNormAct, self).__init__(
_num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
_num_groups(num_channels, num_groups, group_size),
num_channels,
eps=eps,
affine=affine,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class GroupNorm1Act(nn.GroupNorm):
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm):
class LayerNormAct(nn.LayerNorm):
def __init__(
self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
normalization_shape: Union[int, List[int], torch.Size],
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm):
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module

@ -8,6 +8,7 @@ from .convmixer import *
from .convnext import *
from .crossvit import *
from .cspnet import *
from .davit import *
from .deit import *
from .densenet import *
from .dla import *

@ -179,11 +179,11 @@ def load_pretrained(
return
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
except TypeError as e:
# for backwards compat with filter fn that take one arg
state_dict = filter_fn(state_dict)
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:

@ -236,20 +236,7 @@ def push_to_hf_hub(
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n"
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
readme_text += f"- **{k}:** {v}\n"
if 'citation' in model_card:
readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n"
readme_text = generate_readme(model_card, model_name)
readme_path.write_text(readme_text)
# Upload model and return
@ -260,3 +247,51 @@ def push_to_hf_hub(
create_pr=create_pr,
commit_message=commit_message,
)
def generate_readme(model_card, model_name):
readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n"
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
if 'Pretrain Dataset' in model_card['details']:
readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text

@ -43,7 +43,7 @@ from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._builder import build_model_with_cfg
@ -205,6 +205,7 @@ class ConvNeXt(nn.Module):
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_eps=None,
drop_rate=0.,
drop_path_rate=0.,
):
@ -236,10 +237,15 @@ class ConvNeXt(nn.Module):
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -250,7 +256,7 @@ class ConvNeXt(nn.Module):
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
norm_layer(dims[0])
norm_layer(dims[0]),
)
stem_stride = patch_size
else:
@ -301,11 +307,10 @@ class ConvNeXt(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.head_norm_first = head_norm_first
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@ -336,14 +341,7 @@ class ConvNeXt(nn.Module):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
if num_classes == 0:
self.head.norm = nn.Identity()
self.head.fc = nn.Identity()
else:
if not self.head_norm_first:
norm_layer = type(self.stem[-1]) # obtain type from stem norm
self.head.norm = norm_layer(self.num_features)
self.head.fc = nn.Linear(self.num_features, num_classes)
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
@ -384,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict # non-FB checkpoint
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
if 'visual.trunk.stem.0.weight' in state_dict:
out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
if 'visual.head.proj.weight' in state_dict:
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
return out_dict
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
@ -403,10 +409,16 @@ def checkpoint_filter_fn(state_dict, model):
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_convnext(variant, pretrained=False, **kwargs):
if kwargs.get('pretrained_cfg', '') == 'fcmae':
# NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
# This is workaround loading with num_classes=0 w/o removing norm-layer.
kwargs.setdefault('pretrained_strict', False)
model = build_model_with_cfg(
ConvNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
@ -481,10 +493,29 @@ default_cfgs = generate_default_cfgs({
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_small.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
@ -676,6 +707,33 @@ default_cfgs = generate_default_cfgs({
num_classes=0),
'convnextv2_small.untrained': _cfg(),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_augreg_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
})

@ -913,7 +913,7 @@ class CspNet(nn.Module):
# Construct the head
self.num_features = prev_chs
self.head = ClassifierHead(
in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -0,0 +1,679 @@
""" DaViT: Dual Attention Vision Transformers
As described in https://arxiv.org/abs/2204.03645
Input size invariant transformer architecture that combines channel and spacial
attention in each block. The attention mechanisms used are linear in complexity.
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['DaViT']
class ConvPosEnc(nn.Module):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
self.act = nn.GELU() if act else nn.Identity()
def forward(self, x: Tensor):
feat = self.proj(x)
x = x + self.act(feat)
return x
class Stem(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
def __init__(
self,
in_chs=3,
out_chs=96,
stride=4,
norm_layer=LayerNorm2d,
):
super().__init__()
stride = to_2tuple(stride)
self.stride = stride
self.in_chs = in_chs
self.out_chs = out_chs
assert stride[0] == 4 # only setup for stride==4
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=7,
stride=stride,
padding=3,
)
self.norm = norm_layer(out_chs)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
x = self.conv(x)
x = self.norm(x)
return x
class Downsample(nn.Module):
def __init__(
self,
in_chs,
out_chs,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.norm = norm_layer(in_chs)
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=2,
stride=2,
padding=0,
)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
x = F.pad(x, (0, (2 - W % 2) % 2))
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
x = self.conv(x)
return x
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = attention.softmax(dim=-1)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class ChannelBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path2 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.cpe1(x).flatten(2).transpose(1, 2)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path1(cur)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
def window_partition(x: Tensor, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: Tensor):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SpatialBlock(nn.Module):
r""" Windows Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.dim = dim
self.ffn = ffn
self.num_heads = num_heads
self.window_size = to_2tuple(window_size)
self.mlp_ratio = mlp_ratio
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path1 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
# if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path1(x)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth=1,
downsample=True,
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
cpe_act=False
):
super().__init__()
self.grad_checkpointing = False
# downsample embedding layer at the beginning of each stage
if downsample:
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
potential opportunity to integrate with a more general version of ByobNet/ByoaNet
since the logic is similar
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
elif attn_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x: Tensor):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class DaViT(nn.Module):
r""" DaViT
A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
Supports arbitrary input sizes and pyramid feature extraction
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(
self,
in_chans=3,
depths=(1, 1, 3, 1),
embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4,
qkv_bias=True,
norm_layer='layernorm2d',
norm_layer_cl='layernorm',
norm_eps=1e-5,
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
global_pool='avg',
head_norm_first=False,
):
super().__init__()
num_stages = len(embed_dims)
assert num_stages == len(num_heads) == len(depths)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
in_chs = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
stage = DaViTStage(
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
)
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.fc.weight' in state_dict:
return state_dict # non-MSFT checkpoint
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
k = k.replace('head.', 'head.fc.')
k = k.replace('norms.', 'head.norm.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
return out_dict
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_small.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_base.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
})
@register_model
def davit_tiny(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
@register_model
def davit_small(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
@register_model
def davit_base(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
@register_model
def davit_large(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
@register_model
def davit_huge(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
@register_model
def davit_giant(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)

@ -12,7 +12,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
from ._registry import register_model
@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
self,
num_layers,
num_input_features,
bn_size,
growth_rate,
norm_layer=BatchNormAct2d,
drop_rate=0.,
memory_efficient=False,
):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = DenseLayer(
@ -165,12 +172,25 @@ class DenseNet(nn.Module):
"""
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_classes=1000,
in_chans=3,
global_pool='avg',
bn_size=4,
stem_type='',
act_layer='relu',
norm_layer='batchnorm2d',
aa_layer=None,
drop_rate=0,
memory_efficient=False,
aa_stem_only=True,
):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
# Stem
deep_stem = 'deep' in stem_type # 3x3 deep stem
@ -226,8 +246,11 @@ class DenseNet(nn.Module):
dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
current_stride *= 2
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_layer=norm_layer, aa_layer=transition_aa_layer)
num_input_features=num_features,
num_output_features=num_features // 2,
norm_layer=norm_layer,
aa_layer=transition_aa_layer,
)
self.features.add_module(f'transition{i + 1}', trans)
num_features = num_features // 2
@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
aa_layer=BlurPool2d, **kwargs)
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
stem_type='deep', aa_layer=BlurPool2d, **kwargs)
return model
@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs):
def densenet264d_iabn(pretrained=False, **kwargs):
r"""Densenet-264 model with deep stem and Inplace-ABN
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
model = _create_densenet(
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
return model

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._registry import register_model
@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'dpn68': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68b': _cfg(
@ -82,7 +83,16 @@ class BnActConv2d(nn.Module):
class DualPathBlock(nn.Module):
def __init__(
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
self,
in_chs,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
groups,
block_type='normal',
b=False,
):
super(DualPathBlock, self).__init__()
self.num_1x1_c = num_1x1_c
self.inc = inc
@ -167,16 +177,31 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(
self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
self,
k_sec=(3, 4, 20, 3),
inc_sec=(16, 32, 24, 128),
k_r=96,
groups=32,
num_classes=1000,
in_chans=3,
output_stride=32,
global_pool='avg',
small=False,
num_init_features=64,
b=False,
drop_rate=0.,
norm_layer='batchnorm2d',
act_layer='relu',
fc_act_layer='elu',
):
super(DPN, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.b = b
assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
bw_factor = 1 if small else 4
blocks = OrderedDict()
@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs):
**kwargs)
@register_model
def dpn48b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn92(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=64, k_r=96, groups=32,
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn98(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=96, k_r=160, groups=40,
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn131(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=160, groups=40,
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn107(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=200, groups=50,
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))

@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
# FIXME / WARNING
This impl remains a WIP, some configs and models may vanish or change...
Papers:
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
partition_ratio: int = 32
window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
init_values: Optional[float] = None
act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d'
@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
stride: int = 1,
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention
drop_path: float = 0.,
):
super().__init__()
self.nchw_attn = transformer_cfg.use_nchw_attn
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn
self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
def init_weights(self, scheme=''):
@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
hidden_size=None,
pool_type='avg',
drop_rate=0.,
norm_layer=nn.LayerNorm,
act_layer=nn.Tanh,
norm_layer='layernorm2d',
act_layer='tanh',
):
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)),
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
@ -1116,6 +1135,26 @@ class NormMlpHead(nn.Module):
return x
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {}
conv_kwargs = {}
base_kwargs = {}
for k, v in kwargs.items():
if k.startswith('transformer_'):
transformer_kwargs[k.replace('transformer_', '')] = v
elif k.startswith('conv_'):
conv_kwargs[k.replace('conv_', '')] = v
else:
base_kwargs[k] = v
cfg = replace(
cfg,
transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
**base_kwargs
)
return cfg
class MaxxVit(nn.Module):
""" CoaTNet + MaxVit base model.
@ -1130,16 +1169,20 @@ class MaxxVit(nn.Module):
num_classes: int = 1000,
global_pool: str = 'avg',
drop_rate: float = 0.,
drop_path_rate: float = 0.
drop_path_rate: float = 0.,
**kwargs,
):
super().__init__()
img_size = to_2tuple(img_size)
if kwargs:
cfg = _overlay_kwargs(cfg, **kwargs)
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = cfg.embed_dim[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(
in_chs=in_chans,
@ -1150,8 +1193,8 @@ class MaxxVit(nn.Module):
norm_layer=cfg.conv_cfg.norm_layer,
norm_eps=cfg.conv_cfg.norm_eps,
)
stride = self.stem.stride
self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
num_stages = len(cfg.embed_dim)
@ -1175,15 +1218,17 @@ class MaxxVit(nn.Module):
)]
stride *= stage_stride
in_chs = out_chs
self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
if cfg.head_hidden_size:
self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpHead(
self.num_features,
num_classes,
hidden_size=cfg.head_hidden_size,
hidden_size=self.head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=final_norm_layer,
@ -1230,9 +1275,7 @@ class MaxxVit(nn.Module):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -1353,6 +1396,7 @@ def _next_cfg(
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=None,
no_block_attn=False,
init_values=1e-6,
rel_pos_type='mlp', # MLP by default for maxxvit
rel_pos_dim=512,
@ -1373,6 +1417,7 @@ def _next_cfg(
expand_first=False,
pool_type=pool_type,
window_size=window_size,
no_block_attn=no_block_attn, # enabled for MaxxViT-V2
init_values=init_values[1],
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
@ -1399,8 +1444,8 @@ def _tf_cfg():
model_cfgs = dict(
# Fiddling with configs / defaults / still pretraining
coatnet_pico_rw_224=MaxxVitCfg(
# timm specific CoAtNet configs
coatnet_pico_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 3, 5, 2),
stem_width=(32, 64),
@ -1409,7 +1454,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_nano_rw_224=MaxxVitCfg(
coatnet_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1419,7 +1464,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_0_rw_224=MaxxVitCfg(
coatnet_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1428,7 +1473,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
),
),
coatnet_1_rw_224=MaxxVitCfg(
coatnet_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1438,7 +1483,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
)
),
coatnet_2_rw_224=MaxxVitCfg(
coatnet_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1448,7 +1493,7 @@ model_cfgs = dict(
#init_values=1e-6,
),
),
coatnet_3_rw_224=MaxxVitCfg(
coatnet_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1459,8 +1504,8 @@ model_cfgs = dict(
),
),
# Highly experimental configs
coatnet_bn_0_rw_224=MaxxVitCfg(
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
coatnet_bn_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1471,7 +1516,7 @@ model_cfgs = dict(
transformer_norm_layer='batchnorm2d',
)
),
coatnet_rmlp_nano_rw_224=MaxxVitCfg(
coatnet_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1482,7 +1527,7 @@ model_cfgs = dict(
rel_pos_dim=384,
),
),
coatnet_rmlp_0_rw_224=MaxxVitCfg(
coatnet_rmlp_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1491,7 +1536,7 @@ model_cfgs = dict(
rel_pos_type='mlp',
),
),
coatnet_rmlp_1_rw_224=MaxxVitCfg(
coatnet_rmlp_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1503,7 +1548,7 @@ model_cfgs = dict(
rel_pos_dim=384, # was supposed to be 512, woops
),
),
coatnet_rmlp_1_rw2_224=MaxxVitCfg(
coatnet_rmlp_1_rw2=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1513,7 +1558,7 @@ model_cfgs = dict(
rel_pos_dim=512, # was supposed to be 512, woops
),
),
coatnet_rmlp_2_rw_224=MaxxVitCfg(
coatnet_rmlp_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1524,7 +1569,7 @@ model_cfgs = dict(
rel_pos_type='mlp'
),
),
coatnet_rmlp_3_rw_224=MaxxVitCfg(
coatnet_rmlp_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1536,14 +1581,14 @@ model_cfgs = dict(
),
),
coatnet_nano_cc_224=MaxxVitCfg(
coatnet_nano_cc=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(),
),
coatnext_nano_rw_224=MaxxVitCfg(
coatnext_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1555,89 +1600,95 @@ model_cfgs = dict(
),
# Trying to be like the CoAtNet paper configs
coatnet_0_224=MaxxVitCfg(
coatnet_0=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 5, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_1_224=MaxxVitCfg(
coatnet_1=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_2_224=MaxxVitCfg(
coatnet_2=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=128,
head_hidden_size=1024,
),
coatnet_3_224=MaxxVitCfg(
coatnet_3=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_4_224=MaxxVitCfg(
coatnet_4=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_5_224=MaxxVitCfg(
coatnet_5=MaxxVitCfg(
embed_dim=(256, 512, 1280, 2048),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=2048,
),
# Experimental MaxVit configs
maxvit_pico_rw_256=MaxxVitCfg(
maxvit_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(),
),
maxvit_nano_rw_256=MaxxVitCfg(
maxvit_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_224=MaxxVitCfg(
maxvit_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_256=MaxxVitCfg(
maxvit_tiny_pm=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_rmlp_pico_rw_256=MaxxVitCfg(
maxvit_rmlp_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_small_rw_224=MaxxVitCfg(
maxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
@ -1647,26 +1698,18 @@ model_cfgs = dict(
init_values=1e-6,
),
),
maxvit_rmlp_small_rw_256=MaxxVitCfg(
maxvit_rmlp_base_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
@ -1674,33 +1717,50 @@ model_cfgs = dict(
weight_init='normal',
**_next_cfg(),
),
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_next_cfg(),
),
maxxvit_rmlp_small_rw_256=MaxxVitCfg(
maxxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
),
maxxvit_rmlp_base_rw_224=MaxxVitCfg(
maxxvitv2_nano_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
weight_init='normal',
**_next_cfg(
no_block_attn=True,
rel_pos_type='bias',
),
),
maxxvit_rmlp_large_rw_224=MaxxVitCfg(
maxxvitv2_rmlp_base_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 12, 2),
block_type=('M',) * 4,
stem_width=(64, 128),
**_next_cfg(),
**_next_cfg(
no_block_attn=True,
),
),
maxxvitv2_rmlp_large_rw=MaxxVitCfg(
embed_dim=(160, 320, 640, 1280),
depths=(2, 6, 16, 2),
block_type=('M',) * 4,
stem_width=(80, 160),
head_hidden_size=1280,
**_next_cfg(
no_block_attn=True,
),
),
# Trying to be like the MaxViT paper configs
@ -1752,11 +1812,29 @@ model_cfgs = dict(
)
def checkpoint_filter_fn(state_dict, model: nn.Module):
model_state_dict = model.state_dict()
out_dict = {}
for k, v in state_dict.items():
if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
# adapt between conv2d / linear layers
assert v.ndim in (2, 4)
v = v.reshape(model_state_dict[k].shape)
out_dict[k] = v
return out_dict
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
if cfg_variant is None:
if variant in model_cfgs:
cfg_variant = variant
else:
cfg_variant = '_'.join(variant.split('_')[:-1])
return build_model_with_cfg(
MaxxVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
model_cfg=model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
@ -1772,149 +1850,218 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# Fiddling with configs / defaults / still pretraining
'coatnet_pico_rw_224': _cfg(url=''),
'coatnet_nano_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
'coatnet_pico_rw_224.untrained': _cfg(url=''),
'coatnet_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
crop_pct=0.9),
'coatnet_0_rw_224': _cfg(
'coatnet_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
'coatnet_1_rw_224': _cfg(
'coatnet_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
),
'coatnet_2_rw_224': _cfg(url=''),
'coatnet_3_rw_224': _cfg(url=''),
# Highly experimental configs
'coatnet_bn_0_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
#'coatnet_3_rw_224.untrained': _cfg(url=''),
# Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
'coatnet_bn_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95),
'coatnet_rmlp_nano_rw_224': _cfg(
'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
crop_pct=0.9),
'coatnet_rmlp_0_rw_224': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg(
'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_rmlp_1_rw2_224': _cfg(url=''),
'coatnet_rmlp_2_rw_224': _cfg(
'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
'coatnet_rmlp_3_rw_224': _cfg(url=''),
'coatnet_nano_cc_224': _cfg(url=''),
'coatnext_nano_rw_224': _cfg(
'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
'coatnet_nano_cc_224.untrained': _cfg(url=''),
'coatnext_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
crop_pct=0.9),
# Trying to be like the CoAtNet paper configs
'coatnet_0_224': _cfg(url=''),
'coatnet_1_224': _cfg(url=''),
'coatnet_2_224': _cfg(url=''),
'coatnet_3_224': _cfg(url=''),
'coatnet_4_224': _cfg(url=''),
'coatnet_5_224': _cfg(url=''),
# Experimental configs
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256': _cfg(
# ImagenNet-12k pretrain CoAtNet
'coatnet_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_3_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
'coatnet_0_224.untrained': _cfg(url=''),
'coatnet_1_224.untrained': _cfg(url=''),
'coatnet_2_224.untrained': _cfg(url=''),
'coatnet_3_224.untrained': _cfg(url=''),
'coatnet_4_224.untrained': _cfg(url=''),
'coatnet_5_224.untrained': _cfg(url=''),
# timm specific MaxVit configs, ImageNet-1k pretrain or untrained
'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(
'maxvit_tiny_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
'maxvit_tiny_rw_256': _cfg(
'maxvit_tiny_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_pico_rw_256': _cfg(
'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_tiny_rw_256': _cfg(
'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_small_rw_224': _cfg(
'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
crop_pct=0.9,
),
'maxvit_rmlp_small_rw_256': _cfg(
'maxvit_rmlp_small_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
),
'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# timm specific MaxVit w/ ImageNet-12k pretrain
'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
),
'maxxvit_rmlp_nano_rw_256': _cfg(
# timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256': _cfg(
'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_base_rw_224': _cfg(url=''),
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
# timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# MaxViT models ported from official Tensorflow impl
'maxvit_tiny_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_tiny_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_small_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_large_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in21k': _cfg(
url=''),
'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in21k': _cfg(
url=''),
'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_224.in21k': _cfg(
url=''),
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
@ -1978,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@ -2068,6 +2220,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@ -2089,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
@register_model
def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
@register_model

@ -266,9 +266,16 @@ class MobileVitBlock(nn.Module):
self.transformer = nn.Sequential(*[
TransformerBlock(
transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
act_layer=layers.act, norm_layer=transformer_norm_layer)
transformer_dim,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
qkv_bias=True,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer,
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)

@ -496,7 +496,7 @@ class RegNet(nn.Module):
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -156,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -167,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -178,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -189,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -200,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -211,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -222,5 +222,5 @@ def res2next50(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2next50', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))

@ -163,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[1, 1, 1, 1],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -174,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[2, 2, 2, 2],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -186,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -198,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 23, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -210,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 24, 36, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -222,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 30, 48, 8],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -233,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=4, avd=True, avd_first=True))
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -244,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=1, avd=True, avd_first=True))
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

@ -704,7 +704,7 @@ class ResNet(nn.Module):
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
act_layer = get_act_layer(act_layer)
norm_layer = get_norm_layer(norm_layer)
@ -845,77 +845,72 @@ def _create_resnet(variant, pretrained=False, **kwargs):
def resnet10t(pretrained=False, **kwargs):
"""Constructs a ResNet-10-T model.
"""
model_args = dict(
block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet10t', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet14t(pretrained=False, **kwargs):
"""Constructs a ResNet-14-T model.
"""
model_args = dict(
block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet14t', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('resnet18', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet18d(pretrained=False, **kwargs):
"""Constructs a ResNet-18-D model.
"""
model_args = dict(
block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet18d', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet34', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet34d(pretrained=False, **kwargs):
"""Constructs a ResNet-34-D model.
"""
model_args = dict(
block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet34d', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet26(pretrained=False, **kwargs):
"""Constructs a ResNet-26 model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('resnet26', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2])
return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet26t(pretrained=False, **kwargs):
"""Constructs a ResNet-26-T model.
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet26t', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet26d(pretrained=False, **kwargs):
"""Constructs a ResNet-26-D model.
"""
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet26d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -923,83 +918,79 @@ def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('resnet50', pretrained, **model_args)
return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet50d(pretrained=False, **kwargs) -> ResNet:
"""Constructs a ResNet-50-D model.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet50d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet50t(pretrained=False, **kwargs):
"""Constructs a ResNet-50-T model.
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs)
return _create_resnet('resnet50t', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True)
return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
return _create_resnet('resnet101', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet101d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
return _create_resnet('resnet152', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet152d(pretrained=False, **kwargs):
"""Constructs a ResNet-152-D model.
"""
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet152d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet200(pretrained=False, **kwargs):
"""Constructs a ResNet-200 model.
"""
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs)
return _create_resnet('resnet200', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3])
return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs))
@register_model
def resnet200d(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnet200d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs))
@register_model
def tv_resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model with original Torchvision weights.
"""
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('tv_resnet34', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3])
return _create_resnet('tv_resnet34', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1007,23 +998,23 @@ def tv_resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with original Torchvision weights.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('tv_resnet50', pretrained, **model_args)
return _create_resnet('tv_resnet50', pretrained, **dict(model_args, **kwargs))
@register_model
def tv_resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model w/ Torchvision pretrained weights.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs)
return _create_resnet('tv_resnet101', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3])
return _create_resnet('tv_resnet101', pretrained, **dict(model_args, **kwargs))
@register_model
def tv_resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model w/ Torchvision pretrained weights.
"""
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs)
return _create_resnet('tv_resnet152', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3])
return _create_resnet('tv_resnet152', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1034,8 +1025,8 @@ def wide_resnet50_2(pretrained=False, **kwargs):
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs)
return _create_resnet('wide_resnet50_2', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128)
return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1045,8 +1036,8 @@ def wide_resnet101_2(pretrained=False, **kwargs):
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs)
return _create_resnet('wide_resnet101_2', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128)
return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1061,8 +1052,8 @@ def resnet50_gn(pretrained=False, **kwargs):
def resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('resnext50_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1071,40 +1062,40 @@ def resnext50d_32x4d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnext50d_32x4d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnext101_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 32x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('resnext101_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnext101_32x8d(pretrained=False, **kwargs):
"""Constructs a ResNeXt-101 32x8d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('resnext101_32x8d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnext101_64x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt101-64x4d model.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs)
return _create_resnet('resnext101_64x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4)
return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def tv_resnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a ResNeXt50-32x4d model with original Torchvision weights.
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
return _create_resnet('tv_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1114,8 +1105,8 @@ def ig_resnext101_32x8d(pretrained=False, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
return _create_resnet('ig_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1125,8 +1116,8 @@ def ig_resnext101_32x16d(pretrained=False, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
return _create_resnet('ig_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1136,8 +1127,8 @@ def ig_resnext101_32x32d(pretrained=False, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs)
return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32)
return _create_resnet('ig_resnext101_32x32d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1147,8 +1138,8 @@ def ig_resnext101_32x48d(pretrained=False, **kwargs):
`"Exploring the Limits of Weakly Supervised Pretraining" <https://arxiv.org/abs/1805.00932>`_
Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs)
return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48)
return _create_resnet('ig_resnext101_32x48d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1157,8 +1148,8 @@ def ssl_resnet18(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('ssl_resnet18', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
return _create_resnet('ssl_resnet18', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1168,7 +1159,7 @@ def ssl_resnet50(pretrained=False, **kwargs):
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('ssl_resnet50', pretrained, **model_args)
return _create_resnet('ssl_resnet50', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1177,8 +1168,8 @@ def ssl_resnext50_32x4d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
return _create_resnet('ssl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1187,8 +1178,8 @@ def ssl_resnext101_32x4d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
return _create_resnet('ssl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1197,8 +1188,8 @@ def ssl_resnext101_32x8d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
return _create_resnet('ssl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1207,8 +1198,8 @@ def ssl_resnext101_32x16d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
return _create_resnet('ssl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1218,8 +1209,8 @@ def swsl_resnet18(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs)
return _create_resnet('swsl_resnet18', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2])
return _create_resnet('swsl_resnet18', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1230,7 +1221,7 @@ def swsl_resnet50(pretrained=False, **kwargs):
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)
return _create_resnet('swsl_resnet50', pretrained, **model_args)
return _create_resnet('swsl_resnet50', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1240,8 +1231,8 @@ def swsl_resnext50_32x4d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4)
return _create_resnet('swsl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1251,8 +1242,8 @@ def swsl_resnext101_32x4d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs)
return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4)
return _create_resnet('swsl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1262,8 +1253,8 @@ def swsl_resnext101_32x8d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs)
return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8)
return _create_resnet('swsl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1273,8 +1264,8 @@ def swsl_resnext101_32x16d(pretrained=False, **kwargs):
`"Billion-scale Semi-Supervised Learning for Image Classification" <https://arxiv.org/abs/1905.00546>`_
Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs)
return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16)
return _create_resnet('swsl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1285,8 +1276,8 @@ def ecaresnet26t(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet26t', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1295,8 +1286,8 @@ def ecaresnet50d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50d', pretrained, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1306,8 +1297,8 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
@register_model
@ -1317,8 +1308,8 @@ def ecaresnet50t(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet50t', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1327,8 +1318,8 @@ def ecaresnetlight(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnetlight', pretrained, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1337,8 +1328,8 @@ def ecaresnet101d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet101d', pretrained, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1348,8 +1339,8 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs))
@register_model
@ -1358,8 +1349,8 @@ def ecaresnet200d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet200d', pretrained, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1368,8 +1359,8 @@ def ecaresnet269d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnet269d', pretrained, **model_args)
block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1380,8 +1371,8 @@ def ecaresnext26t_32x4d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1392,54 +1383,54 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs)
return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'))
return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet18(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet18', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet34(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet34', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet50', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet50t(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet50t', pretrained, **model_args)
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet101(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet101', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152(pretrained=False, **kwargs):
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet152', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnet152d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet152d', pretrained, **model_args)
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1447,9 +1438,9 @@ def seresnet200d(pretrained=False, **kwargs):
"""Constructs a ResNet-200-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet200d', pretrained, **model_args)
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1457,9 +1448,9 @@ def seresnet269d(pretrained=False, **kwargs):
"""Constructs a ResNet-269-D model with SE attn.
"""
model_args = dict(
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnet269d', pretrained, **model_args)
block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep',
avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1470,8 +1461,8 @@ def seresnext26d_32x4d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26d_32x4d', pretrained, **model_args)
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1482,8 +1473,8 @@ def seresnext26t_32x4d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32,
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext26t_32x4d', pretrained, **model_args)
stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1499,24 +1490,24 @@ def seresnext26tn_32x4d(pretrained=False, **kwargs):
def seresnext50_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext50_32x4d', pretrained, **model_args)
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnext101_32x4d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext101_32x4d', pretrained, **model_args)
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs))
@register_model
def seresnext101_32x8d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1524,32 +1515,32 @@ def seresnext101d_32x8d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext101d_32x8d', pretrained, **model_args)
block_args=dict(attn_layer='se'))
return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
def senet154(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('senet154', pretrained, **model_args)
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'))
return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs))
@register_model
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d)
return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs))
@register_model
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d)
return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1558,8 +1549,8 @@ def resnetblur50d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1568,16 +1559,25 @@ def resnetblur101d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnetaa34d(pretrained=False, **kwargs):
"""Constructs a ResNet-34-D model w/ avgpool anti-aliasing
"""
model_args = dict(
block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs))
@register_model
def resnetaa50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with avgpool anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs)
return _create_resnet('resnetaa50', pretrained, **model_args)
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d)
return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1586,8 +1586,8 @@ def resnetaa50d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1596,8 +1596,8 @@ def resnetaa101d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True)
return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1606,8 +1606,8 @@ def seresnetaa50d(pretrained=False, **kwargs):
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'))
return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1617,8 +1617,8 @@ def seresnextaa101d_32x8d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args)
block_args=dict(attn_layer='se'))
return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1630,8 +1630,8 @@ def resnetrs50(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1643,8 +1643,8 @@ def resnetrs101(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1656,8 +1656,8 @@ def resnetrs152(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1669,8 +1669,8 @@ def resnetrs200(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1682,8 +1682,8 @@ def resnetrs270(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs))
@ -1696,8 +1696,8 @@ def resnetrs350(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs))
@register_model
@ -1709,5 +1709,5 @@ def resnetrs420(pretrained=False, **kwargs):
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)
avg_down=True, block_args=dict(attn_layer=attn_layer))
return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs))

@ -746,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
@register_model
def resnetv2_50(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50t', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='tiered', avg_down=True, **kwargs)
stem_type='tiered', avg_down=True)
return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101', pretrained=pretrained,
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152', pretrained=pretrained,
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152d', pretrained=pretrained,
model_args = dict(
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))
# Experimental configs (may change / be removed)
@register_model
def resnetv2_50d_gn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_gn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
stem_type='deep', avg_down=True, zero_init_last=True)
return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_frn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))

@ -1029,6 +1029,10 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/',
@ -1498,6 +1502,17 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
return model
@register_model
def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
# Experimental models below
@register_model

@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@torch.jit.ignore
def group_matcher(self, coarse=False):

@ -7,6 +7,8 @@ import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .model_ema import ModelEma
@ -100,70 +102,6 @@ def extract_spp_stats(
return hook.stats
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
"""
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
"""
assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(root_module, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
# Raise assertion here because we can't convert it in place
raise AssertionError(
"You have provided a batch norm layer as the `root module`. Please use "
@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
# It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
# convert it in place, but will return the converted result. In this case `res` holds the converted
# result and we may try to re-assign the named module
if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(m, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
_add_submodule(root_module, n, res)
# Unfreeze batch norm
else:
res = unfreeze_batch_norm_2d(m)
# Ditto. See note above in mode == 'freeze' branch
if isinstance(m, FrozenBatchNorm2d):
if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
_add_submodule(root_module, n, res)

@ -1 +1 @@
__version__ = '0.8.5dev0'
__version__ = '0.8.8dev0'

Loading…
Cancel
Save