Merge branch 'fffffgggg54-main'

pull/1654/head
Ross Wightman 2 years ago
commit 29fda20e6d

@ -40,9 +40,10 @@ jobs:
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: |
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo sed -i 's/azure\.//' /etc/apt/sources.list
sudo apt update
sudo apt install -y google-perftools
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install requirements
run: |
pip install -r requirements.txt

@ -21,12 +21,73 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### 🤗 Survey: Feedback Appreciated 🤗
For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing.
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Jan 20, 2023
* Add two convnext 12k -> 1k fine-tunes at 384x384
* `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384
* `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384
* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models
|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)|
|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:|
|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22|
|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76|
|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99|
|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15|
|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84|
|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90|
|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95|
|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74|
|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43|
|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64|
|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77|
|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99|
|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22|
|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15|
|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78|
|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90|
|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84|
|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77|
|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59|
|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65|
|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42|
|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35|
|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13|
|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01|
|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38|
|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78|
|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30|
|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17|
|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92|
|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60|
|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11|
|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78|
|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47|
|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05|
|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05|
|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92|
|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28|
|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04|
|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73|
|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34|
|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80|
|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41|
|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86|
### Jan 11, 2023
* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
* `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)
* `convnext_tiny.in12k_ft_in1k` - 84.2 @ 224, 84.5 @ 288
* `convnext_small.in12k_ft_in1k` - 85.2 @ 224, 85.3 @ 288
### Jan 6, 2023
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`

@ -22,7 +22,7 @@ from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
has_apex = False
try:
@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
# codegen (model compilation) options
scripting_group = parser.add_mutually_exclusive_group()
@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd optimization.")
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
@ -168,19 +170,21 @@ def count_params(model: nn.Module):
def resolve_precision(precision: str):
assert precision in ('amp', 'float16', 'bfloat16', 'float32')
use_amp = False
assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
amp_dtype = None # amp disabled
model_dtype = torch.float32
data_dtype = torch.float32
if precision == 'amp':
use_amp = True
amp_dtype = torch.float16
elif precision == 'amp_bfloat16':
amp_dtype = torch.bfloat16
elif precision == 'float16':
model_dtype = torch.float16
data_dtype = torch.float16
elif precision == 'bfloat16':
model_dtype = torch.bfloat16
data_dtype = torch.bfloat16
return use_amp, model_dtype, data_dtype
return amp_dtype, model_dtype, data_dtype
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@ -228,9 +232,12 @@ class BenchmarkRunner:
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if self.amp_dtype is not None:
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
else:
self.amp_autocast = suppress
if fuser:
set_jit_fuser(fuser)
@ -243,6 +250,7 @@ class BenchmarkRunner:
drop_rate=kwargs.pop('drop', 0.),
drop_path_rate=kwargs.pop('drop_path', None),
drop_block_rate=kwargs.pop('drop_block', None),
**kwargs.pop('model_kwargs', {}),
)
self.model.to(
device=self.device,
@ -560,7 +568,7 @@ def _try_run(
def benchmark(args):
if args.amp:
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
args.precision = 'amp'
args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')

@ -20,7 +20,7 @@ import torch
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
try:
from apex import amp
@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -170,12 +173,19 @@ def main():
set_jit_fuser(args.fuser)
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'

@ -38,7 +38,7 @@ An ImageNet test set of 10,000 images sampled from new images roughly 10 years a
### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv)
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occuring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
For clean validation with same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv)

@ -27,7 +27,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'convnextv2_huge*']
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
@ -53,7 +53,7 @@ MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
TARGET_FWD_FX_SIZE = 128
MAX_FWD_FX_SIZE = 224
MAX_FWD_FX_SIZE = 256
TARGET_BWD_FX_SIZE = 128
MAX_BWD_FX_SIZE = 224
@ -269,7 +269,7 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point
'vit_large_*', 'vit_huge_*', 'vit_gi*',
]

@ -1,6 +1,6 @@
from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\
rand_augment_transform, auto_augment_transform
from .config import resolve_data_config
from .config import resolve_data_config, resolve_model_data_config
from .constants import *
from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
from .dataset_factory import create_dataset

@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__)
def resolve_data_config(
args,
default_cfg=None,
args=None,
pretrained_cfg=None,
model=None,
use_test_size=False,
verbose=False
):
new_config = {}
default_cfg = default_cfg or {}
if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
default_cfg = model.default_cfg
assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config."
args = args or {}
pretrained_cfg = pretrained_cfg or {}
if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'):
pretrained_cfg = model.pretrained_cfg
data_config = {}
# Resolve input/image size
in_chans = 3
@ -32,65 +34,94 @@ def resolve_data_config(
assert isinstance(args['img_size'], int)
input_size = (in_chans, args['img_size'], args['img_size'])
else:
if use_test_size and default_cfg.get('test_input_size', None) is not None:
input_size = default_cfg['test_input_size']
elif default_cfg.get('input_size', None) is not None:
input_size = default_cfg['input_size']
new_config['input_size'] = input_size
if use_test_size and pretrained_cfg.get('test_input_size', None) is not None:
input_size = pretrained_cfg['test_input_size']
elif pretrained_cfg.get('input_size', None) is not None:
input_size = pretrained_cfg['input_size']
data_config['input_size'] = input_size
# resolve interpolation method
new_config['interpolation'] = 'bicubic'
data_config['interpolation'] = 'bicubic'
if args.get('interpolation', None):
new_config['interpolation'] = args['interpolation']
elif default_cfg.get('interpolation', None):
new_config['interpolation'] = default_cfg['interpolation']
data_config['interpolation'] = args['interpolation']
elif pretrained_cfg.get('interpolation', None):
data_config['interpolation'] = pretrained_cfg['interpolation']
# resolve dataset + model mean for normalization
new_config['mean'] = IMAGENET_DEFAULT_MEAN
data_config['mean'] = IMAGENET_DEFAULT_MEAN
if args.get('mean', None) is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
new_config['mean'] = mean
elif default_cfg.get('mean', None):
new_config['mean'] = default_cfg['mean']
data_config['mean'] = mean
elif pretrained_cfg.get('mean', None):
data_config['mean'] = pretrained_cfg['mean']
# resolve dataset + model std deviation for normalization
new_config['std'] = IMAGENET_DEFAULT_STD
data_config['std'] = IMAGENET_DEFAULT_STD
if args.get('std', None) is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
new_config['std'] = std
elif default_cfg.get('std', None):
new_config['std'] = default_cfg['std']
data_config['std'] = std
elif pretrained_cfg.get('std', None):
data_config['std'] = pretrained_cfg['std']
# resolve default inference crop
crop_pct = DEFAULT_CROP_PCT
if args.get('crop_pct', None):
crop_pct = args['crop_pct']
else:
if use_test_size and default_cfg.get('test_crop_pct', None):
crop_pct = default_cfg['test_crop_pct']
elif default_cfg.get('crop_pct', None):
crop_pct = default_cfg['crop_pct']
new_config['crop_pct'] = crop_pct
if use_test_size and pretrained_cfg.get('test_crop_pct', None):
crop_pct = pretrained_cfg['test_crop_pct']
elif pretrained_cfg.get('crop_pct', None):
crop_pct = pretrained_cfg['crop_pct']
data_config['crop_pct'] = crop_pct
# resolve default crop percentage
crop_mode = DEFAULT_CROP_MODE
if args.get('crop_mode', None):
crop_mode = args['crop_mode']
elif default_cfg.get('crop_mode', None):
crop_mode = default_cfg['crop_mode']
new_config['crop_mode'] = crop_mode
elif pretrained_cfg.get('crop_mode', None):
crop_mode = pretrained_cfg['crop_mode']
data_config['crop_mode'] = crop_mode
if verbose:
_logger.info('Data processing configuration for current model + dataset:')
for n, v in new_config.items():
for n, v in data_config.items():
_logger.info('\t%s: %s' % (n, str(v)))
return new_config
return data_config
def resolve_model_data_config(
model,
args=None,
pretrained_cfg=None,
use_test_size=False,
verbose=False,
):
""" Resolve Model Data Config
This is equivalent to resolve_data_config() but with arguments re-ordered to put model first.
Args:
model (nn.Module): the model instance
args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg)
pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model)
use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution
verbose (bool): enable extra logging of resolved values
Returns:
dictionary of config
"""
return resolve_data_config(
args=args,
pretrained_cfg=pretrained_cfg,
model=model,
use_test_size=use_test_size,
verbose=verbose,
)

@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d

@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
self.in_features = in_features
self.use_conv = use_conv
self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv)
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
if global_pool != self.global_pool.pool_type:
self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv)
self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity()
num_pooled_features = self.in_features * self.global_pool.feat_mult()
self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv)
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
if self.drop_rate:

@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None):
return module_output
class FrozenBatchNormAct2d(torch.nn.Module):
"""
BatchNormAct2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super().__init__()
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
x = x * scale + bias
x = self.act(self.drop(x))
return x
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
res = FrozenBatchNormAct2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNormAct2d):
res = BatchNormAct2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size):
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(
self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
num_groups=32,
eps=1e-5,
affine=True,
group_size=None,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNormAct, self).__init__(
_num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
_num_groups(num_channels, num_groups, group_size),
num_channels,
eps=eps,
affine=affine,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class GroupNorm1Act(nn.GroupNorm):
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm):
class LayerNormAct(nn.LayerNorm):
def __init__(
self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
normalization_shape: Union[int, List[int], torch.Size],
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm):
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module

@ -8,6 +8,7 @@ from .convmixer import *
from .convnext import *
from .crossvit import *
from .cspnet import *
from .davit import *
from .deit import *
from .densenet import *
from .dla import *

@ -179,11 +179,11 @@ def load_pretrained(
return
if filter_fn is not None:
# for backwards compat with filter fn that take one arg, try one first, the two
try:
state_dict = filter_fn(state_dict)
except TypeError:
state_dict = filter_fn(state_dict, model)
except TypeError as e:
# for backwards compat with filter fn that take one arg
state_dict = filter_fn(state_dict)
input_convs = pretrained_cfg.get('first_conv', None)
if input_convs is not None and in_chans != 3:

@ -209,6 +209,7 @@ def push_to_hf_hub(
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
):
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@ -232,9 +233,10 @@ def push_to_hf_hub(
# Add readme if it does not exist
if not has_readme:
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
readme_text = generate_readme(model_card, model_name)
readme_path.write_text(readme_text)
# Upload model and return
@ -245,3 +247,51 @@ def push_to_hf_hub(
create_pr=create_pr,
commit_message=commit_message,
)
def generate_readme(model_card, model_name):
readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n"
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
if 'details' in model_card and 'Dataset' in model_card['details']:
readme_text += 'datasets:\n'
readme_text += f"- {model_card['details']['Dataset'].lower()}\n"
if 'Pretrain Dataset' in model_card['details']:
readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
if isinstance(v, (list, tuple)):
readme_text += f"- **{k}:**\n"
for vi in v:
readme_text += f" - {vi}\n"
elif isinstance(v, dict):
readme_text += f"- **{k}:**\n"
for ki, vi in v.items():
readme_text += f" - {ki}: {vi}\n"
else:
readme_text += f"- **{k}:** {v}\n"
if 'usage' in model_card:
readme_text += f"\n## Model Usage\n"
readme_text += model_card['usage']
readme_text += '\n'
if 'comparison' in model_card:
readme_text += f"\n## Model Comparison\n"
readme_text += model_card['comparison']
readme_text += '\n'
if 'citation' in model_card:
readme_text += f"\n## Citation\n"
if not isinstance(model_card['citation'], (list, tuple)):
citations = [model_card['citation']]
else:
citations = model_card['citation']
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text

@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
def interleave_blocks(
types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
types: Tuple[str, str], d,
every: Union[int, List[int]] = 1,
first: bool = False,
**kwargs,
) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack
"""
@ -1587,15 +1590,32 @@ class ByobNet(nn.Module):
in_chans=3,
global_pool='avg',
output_stride=32,
zero_init_last=True,
img_size=None,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (ByoModelCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'

@ -43,7 +43,7 @@ from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._builder import build_model_with_cfg
@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module):
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module):
class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_rate (float): Head dropout rate
drop_path_rate (float): Stochastic depth rate. Default: 0.
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
@ -215,19 +205,47 @@ class ConvNeXt(nn.Module):
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_eps=None,
drop_rate=0.,
drop_path_rate=0.,
):
"""
Args:
in_chans (int): Number of input image channels (default: 3)
num_classes (int): Number of classes for classification head (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
ls_init_value (float): Init value for Layer Scale (default: 1e-6)
stem_type (str): Type of stem (default: 'patch')
patch_size (int): Stem patch size for patch stem (default: 4)
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
head_norm_first (bool): Apply normalization before global pool + head (default: False)
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
conv_bias (bool): Use bias layers w/ all convolutions (default: True)
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
act_layer (Union[str, nn.Module]): Activation Layer
norm_layer (Union[str, nn.Module]): Normalization Layer
drop_rate (float): Head dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth rate (default: 0.)
"""
super().__init__()
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -238,7 +256,7 @@ class ConvNeXt(nn.Module):
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
norm_layer(dims[0])
norm_layer(dims[0]),
)
stem_stride = patch_size
else:
@ -279,7 +297,7 @@ class ConvNeXt(nn.Module):
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
norm_layer_cl=norm_layer_cl,
))
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
@ -289,11 +307,10 @@ class ConvNeXt(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.head_norm_first = head_norm_first
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@ -324,14 +341,7 @@ class ConvNeXt(nn.Module):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
if num_classes == 0:
self.head.norm = nn.Identity()
self.head.fc = nn.Identity()
else:
if not self.head_norm_first:
norm_layer = type(self.stem[-1]) # obtain type from stem norm
self.head.norm = norm_layer(self.num_features)
self.head.fc = nn.Linear(self.num_features, num_classes)
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
@ -372,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict # non-FB checkpoint
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
if 'visual.trunk.stem.0.weight' in state_dict:
out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
if 'visual.head.proj.weight' in state_dict:
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
return out_dict
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
@ -391,10 +409,16 @@ def checkpoint_filter_fn(state_dict, model):
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_convnext(variant, pretrained=False, **kwargs):
if kwargs.get('pretrained_cfg', '') == 'fcmae':
# NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
# This is workaround loading with num_classes=0 w/o removing norm-layer.
kwargs.setdefault('pretrained_strict', False)
model = build_model_with_cfg(
ConvNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
@ -469,10 +493,29 @@ default_cfgs = generate_default_cfgs({
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small.in12k_ft_in1k_384': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_small.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
@ -664,6 +707,33 @@ default_cfgs = generate_default_cfgs({
num_classes=0),
'convnextv2_small.untrained': _cfg(),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_augreg_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
})

@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
@ -518,7 +518,7 @@ class CrossStage(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
@ -558,7 +558,7 @@ class CrossStage(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -597,7 +597,7 @@ class CrossStage3(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage3, self).__init__()
first_dilation = first_dilation or dilation
@ -635,7 +635,7 @@ class CrossStage3(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -668,7 +668,7 @@ class DarkStage(nn.Module):
avg_down=False,
block_fn=BottleneckBlock,
block_dpr=None,
**block_kwargs
**block_kwargs,
):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@ -715,7 +715,7 @@ def create_csp_stem(
padding='',
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None
aa_layer=None,
):
stem = nn.Sequential()
feature_info = []
@ -738,7 +738,7 @@ def create_csp_stem(
stride=conv_stride,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
norm_layer=norm_layer,
))
stem_stride *= conv_stride
prev_chs = chs
@ -800,7 +800,7 @@ def create_csp_stages(
cfg: CspModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any]
stem_feat: Dict[str, Any],
):
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
@ -868,12 +868,27 @@ class CspNet(nn.Module):
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (CspModelCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layer_args = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
@ -898,7 +913,7 @@ class CspNet(nn.Module):
# Construct the head
self.num_features = prev_chs
self.head = ClassifierHead(
in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@ -0,0 +1,679 @@
""" DaViT: Dual Attention Vision Transformers
As described in https://arxiv.org/abs/2204.03645
Input size invariant transformer architecture that combines channel and spacial
attention in each block. The attention mechanisms used are linear in complexity.
DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below
"""
# Copyright (c) 2022 Mingyu Ding
# All rights reserved.
# This source code is licensed under the MIT license
import itertools
from collections import OrderedDict
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['DaViT']
class ConvPosEnc(nn.Module):
def __init__(self, dim: int, k: int = 3, act: bool = False):
super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim)
self.act = nn.GELU() if act else nn.Identity()
def forward(self, x: Tensor):
feat = self.proj(x)
x = x + self.act(feat)
return x
class Stem(nn.Module):
""" Size-agnostic implementation of 2D image to patch embedding,
allowing input size to be adjusted during model forward operation
"""
def __init__(
self,
in_chs=3,
out_chs=96,
stride=4,
norm_layer=LayerNorm2d,
):
super().__init__()
stride = to_2tuple(stride)
self.stride = stride
self.in_chs = in_chs
self.out_chs = out_chs
assert stride[0] == 4 # only setup for stride==4
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=7,
stride=stride,
padding=3,
)
self.norm = norm_layer(out_chs)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1]))
x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0]))
x = self.conv(x)
x = self.norm(x)
return x
class Downsample(nn.Module):
def __init__(
self,
in_chs,
out_chs,
norm_layer=LayerNorm2d,
):
super().__init__()
self.in_chs = in_chs
self.out_chs = out_chs
self.norm = norm_layer(in_chs)
self.conv = nn.Conv2d(
in_chs,
out_chs,
kernel_size=2,
stride=2,
padding=0,
)
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.norm(x)
x = F.pad(x, (0, (2 - W % 2) % 2))
x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2))
x = self.conv(x)
return x
class ChannelAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
def forward(self, x: Tensor):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
k = k * self.scale
attention = k.transpose(-1, -2) @ v
attention = attention.softmax(dim=-1)
x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)
x = x.transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
return x
class ChannelBlock(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.,
qkv_bias=False,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.ffn = ffn
self.norm1 = norm_layer(dim)
self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
self.mlp = Mlp(
in_features=dim,
hidden_features=int(dim * mlp_ratio),
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path2 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
x = self.cpe1(x).flatten(2).transpose(1, 2)
cur = self.norm1(x)
cur = self.attn(cur)
x = x + self.drop_path1(cur)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
def window_partition(x: Tensor, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int):
"""
Args:
windows: (num_windows*B, window_size, window_size, C)
window_size (int): Window size
H (int): Height of image
W (int): Width of image
Returns:
x: (B, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True):
super().__init__()
self.dim = dim
self.window_size = window_size
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x: Tensor):
B_, N, C = x.shape
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
q = q * self.scale
attn = (q @ k.transpose(-2, -1))
attn = self.softmax(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
return x
class SpatialBlock(nn.Module):
r""" Windows Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(
self,
dim,
num_heads,
window_size=7,
mlp_ratio=4.,
qkv_bias=True,
drop_path=0.,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
ffn=True,
cpe_act=False,
):
super().__init__()
self.dim = dim
self.ffn = ffn
self.num_heads = num_heads
self.window_size = to_2tuple(window_size)
self.mlp_ratio = mlp_ratio
self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(
dim,
self.window_size,
num_heads=num_heads,
qkv_bias=qkv_bias,
)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act)
if self.ffn:
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
else:
self.norm2 = None
self.mlp = None
self.drop_path1 = None
def forward(self, x: Tensor):
B, C, H, W = x.shape
shortcut = self.cpe1(x).flatten(2).transpose(1, 2)
x = self.norm1(shortcut)
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
x_windows = window_partition(x, self.window_size)
x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C)
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows)
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
x = window_reverse(attn_windows, self.window_size, Hp, Wp)
# if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.view(B, H * W, C)
x = shortcut + self.drop_path1(x)
x = self.cpe2(x.transpose(1, 2).view(B, C, H, W))
if self.mlp is not None:
x = x.flatten(2).transpose(1, 2)
x = x + self.drop_path2(self.mlp(self.norm2(x)))
x = x.transpose(1, 2).view(B, C, H, W)
return x
class DaViTStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
depth=1,
downsample=True,
attn_types=('spatial', 'channel'),
num_heads=3,
window_size=7,
mlp_ratio=4,
qkv_bias=True,
drop_path_rates=(0, 0),
norm_layer=LayerNorm2d,
norm_layer_cl=nn.LayerNorm,
ffn=True,
cpe_act=False
):
super().__init__()
self.grad_checkpointing = False
# downsample embedding layer at the beginning of each stage
if downsample:
self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
'''
repeating alternating attention blocks in each stage
default: (spatial -> channel) x depth
potential opportunity to integrate with a more general version of ByobNet/ByoaNet
since the logic is similar
'''
stage_blocks = []
for block_idx in range(depth):
dual_attention_block = []
for attn_idx, attn_type in enumerate(attn_types):
if attn_type == 'spatial':
dual_attention_block.append(SpatialBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
window_size=window_size,
))
elif attn_type == 'channel':
dual_attention_block.append(ChannelBlock(
dim=out_chs,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path=drop_path_rates[block_idx],
norm_layer=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act
))
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x: Tensor):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class DaViT(nn.Module):
r""" DaViT
A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645
Supports arbitrary input sizes and pyramid feature extraction
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1)
embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768)
num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24)
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
"""
def __init__(
self,
in_chans=3,
depths=(1, 1, 3, 1),
embed_dims=(96, 192, 384, 768),
num_heads=(3, 6, 12, 24),
window_size=7,
mlp_ratio=4,
qkv_bias=True,
norm_layer='layernorm2d',
norm_layer_cl='layernorm',
norm_eps=1e-5,
attn_types=('spatial', 'channel'),
ffn=True,
cpe_act=False,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_classes=1000,
global_pool='avg',
head_norm_first=False,
):
super().__init__()
num_stages = len(embed_dims)
assert num_stages == len(num_heads) == len(depths)
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer)
in_chs = embed_dims[0]
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
stages = []
for stage_idx in range(num_stages):
out_chs = embed_dims[stage_idx]
stage = DaViTStage(
in_chs,
out_chs,
depth=depths[stage_idx],
downsample=stage_idx > 0,
attn_types=attn_types,
num_heads=num_heads[stage_idx],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop_path_rates=dpr[stage_idx],
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
ffn=ffn,
cpe_act=cpe_act,
)
in_chs = out_chs
stages.append(stage)
self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')]
self.stages = nn.Sequential(*stages)
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.fc.weight' in state_dict:
return state_dict # non-MSFT checkpoint
if 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
import re
out_dict = {}
for k, v in state_dict.items():
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('downsample.proj', 'downsample.conv')
k = k.replace('stages.0.downsample', 'stem')
k = k.replace('head.', 'head.fc.')
k = k.replace('norms.', 'head.norm.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
return out_dict
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
model = build_model_with_cfg(
DaViT,
variant,
pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(flatten_sequential=True, out_indices=out_indices),
**kwargs)
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.95, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
# TODO contact authors to get larger pretrained models
default_cfgs = generate_default_cfgs({
# official microsoft weights from https://github.com/dingmyu/davit
'davit_tiny.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_small.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_base.msft_in1k': _cfg(
hf_hub_id='timm/'),
'davit_large': _cfg(),
'davit_huge': _cfg(),
'davit_giant': _cfg(),
})
@register_model
def davit_tiny(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs)
@register_model
def davit_small(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs)
return _create_davit('davit_small', pretrained=pretrained, **model_kwargs)
@register_model
def davit_base(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs)
return _create_davit('davit_base', pretrained=pretrained, **model_kwargs)
@register_model
def davit_large(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs)
return _create_davit('davit_large', pretrained=pretrained, **model_kwargs)
@register_model
def davit_huge(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs)
return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs)
@register_model
def davit_giant(pretrained=False, **kwargs):
model_kwargs = dict(
depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs)
return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs)

@ -12,7 +12,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
from ._registry import register_model
@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
self,
num_layers,
num_input_features,
bn_size,
growth_rate,
norm_layer=BatchNormAct2d,
drop_rate=0.,
memory_efficient=False,
):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = DenseLayer(
@ -165,12 +172,25 @@ class DenseNet(nn.Module):
"""
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_classes=1000,
in_chans=3,
global_pool='avg',
bn_size=4,
stem_type='',
act_layer='relu',
norm_layer='batchnorm2d',
aa_layer=None,
drop_rate=0,
memory_efficient=False,
aa_stem_only=True,
):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
# Stem
deep_stem = 'deep' in stem_type # 3x3 deep stem
@ -226,8 +246,11 @@ class DenseNet(nn.Module):
dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
current_stride *= 2
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_layer=norm_layer, aa_layer=transition_aa_layer)
num_input_features=num_features,
num_output_features=num_features // 2,
norm_layer=norm_layer,
aa_layer=transition_aa_layer,
)
self.features.add_module(f'transition{i + 1}', trans)
num_features = num_features // 2
@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
aa_layer=BlurPool2d, **kwargs)
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
stem_type='deep', aa_layer=BlurPool2d, **kwargs)
return model
@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs):
def densenet264d_iabn(pretrained=False, **kwargs):
r"""Densenet-264 model with deep stem and Inplace-ABN
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
model = _create_densenet(
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
return model

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._registry import register_model
@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'dpn68': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68b': _cfg(
@ -82,7 +83,16 @@ class BnActConv2d(nn.Module):
class DualPathBlock(nn.Module):
def __init__(
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
self,
in_chs,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
groups,
block_type='normal',
b=False,
):
super(DualPathBlock, self).__init__()
self.num_1x1_c = num_1x1_c
self.inc = inc
@ -167,16 +177,31 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(
self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
self,
k_sec=(3, 4, 20, 3),
inc_sec=(16, 32, 24, 128),
k_r=96,
groups=32,
num_classes=1000,
in_chans=3,
output_stride=32,
global_pool='avg',
small=False,
num_init_features=64,
b=False,
drop_rate=0.,
norm_layer='batchnorm2d',
act_layer='relu',
fc_act_layer='elu',
):
super(DPN, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.b = b
assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
bw_factor = 1 if small else 4
blocks = OrderedDict()
@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs):
**kwargs)
@register_model
def dpn48b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn92(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=64, k_r=96, groups=32,
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn98(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=96, k_r=160, groups=40,
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn131(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=160, groups=40,
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn107(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=200, groups=50,
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))

@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the
The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to
match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match.
# FIXME / WARNING
This impl remains a WIP, some configs and models may vanish or change...
Papers:
MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697
@ -76,6 +73,8 @@ class MaxxVitTransformerCfg:
partition_ratio: int = 32
window_size: Optional[Tuple[int, int]] = None
grid_size: Optional[Tuple[int, int]] = None
no_block_attn: bool = False # disable window block attention for maxvit (ie only grid)
use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order
init_values: Optional[float] = None
act_layer: str = 'gelu'
norm_layer: str = 'layernorm2d'
@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module):
stride: int = 1,
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(),
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(),
use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU
use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention
drop_path: float = 0.,
):
super().__init__()
self.nchw_attn = transformer_cfg.use_nchw_attn
conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock
self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path)
attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path)
partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl
self.nchw_attn = use_nchw_attn
self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None
partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl
self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs)
self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs)
def init_weights(self, scheme=''):
@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module):
hidden_size=None,
pool_type='avg',
drop_rate=0.,
norm_layer=nn.LayerNorm,
act_layer=nn.Tanh,
norm_layer='layernorm2d',
act_layer='tanh',
):
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(in_features, hidden_size)),
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
@ -1116,6 +1135,26 @@ class NormMlpHead(nn.Module):
return x
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {}
conv_kwargs = {}
base_kwargs = {}
for k, v in kwargs.items():
if k.startswith('transformer_'):
transformer_kwargs[k.replace('transformer_', '')] = v
elif k.startswith('conv_'):
conv_kwargs[k.replace('conv_', '')] = v
else:
base_kwargs[k] = v
cfg = replace(
cfg,
transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
**base_kwargs
)
return cfg
class MaxxVit(nn.Module):
""" CoaTNet + MaxVit base model.
@ -1130,16 +1169,20 @@ class MaxxVit(nn.Module):
num_classes: int = 1000,
global_pool: str = 'avg',
drop_rate: float = 0.,
drop_path_rate: float = 0.
drop_path_rate: float = 0.,
**kwargs,
):
super().__init__()
img_size = to_2tuple(img_size)
if kwargs:
cfg = _overlay_kwargs(cfg, **kwargs)
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = cfg.embed_dim[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
self.stem = Stem(
in_chs=in_chans,
@ -1150,8 +1193,8 @@ class MaxxVit(nn.Module):
norm_layer=cfg.conv_cfg.norm_layer,
norm_eps=cfg.conv_cfg.norm_eps,
)
stride = self.stem.stride
self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')]
feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))])
num_stages = len(cfg.embed_dim)
@ -1175,15 +1218,17 @@ class MaxxVit(nn.Module):
)]
stride *= stage_stride
in_chs = out_chs
self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps)
if cfg.head_hidden_size:
self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpHead(
self.num_features,
num_classes,
hidden_size=cfg.head_hidden_size,
hidden_size=self.head_hidden_size,
pool_type=global_pool,
drop_rate=drop_rate,
norm_layer=final_norm_layer,
@ -1230,9 +1275,7 @@ class MaxxVit(nn.Module):
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
self.head.reset(num_classes, global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -1353,6 +1396,7 @@ def _next_cfg(
transformer_norm_layer='layernorm2d',
transformer_norm_layer_cl='layernorm',
window_size=None,
no_block_attn=False,
init_values=1e-6,
rel_pos_type='mlp', # MLP by default for maxxvit
rel_pos_dim=512,
@ -1373,6 +1417,7 @@ def _next_cfg(
expand_first=False,
pool_type=pool_type,
window_size=window_size,
no_block_attn=no_block_attn, # enabled for MaxxViT-V2
init_values=init_values[1],
norm_layer=transformer_norm_layer,
norm_layer_cl=transformer_norm_layer_cl,
@ -1399,8 +1444,8 @@ def _tf_cfg():
model_cfgs = dict(
# Fiddling with configs / defaults / still pretraining
coatnet_pico_rw_224=MaxxVitCfg(
# timm specific CoAtNet configs
coatnet_pico_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 3, 5, 2),
stem_width=(32, 64),
@ -1409,7 +1454,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_nano_rw_224=MaxxVitCfg(
coatnet_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1419,7 +1464,7 @@ model_cfgs = dict(
conv_attn_ratio=0.25,
),
),
coatnet_0_rw_224=MaxxVitCfg(
coatnet_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1428,7 +1473,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
),
),
coatnet_1_rw_224=MaxxVitCfg(
coatnet_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1438,7 +1483,7 @@ model_cfgs = dict(
transformer_shortcut_bias=False,
)
),
coatnet_2_rw_224=MaxxVitCfg(
coatnet_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1448,7 +1493,7 @@ model_cfgs = dict(
#init_values=1e-6,
),
),
coatnet_3_rw_224=MaxxVitCfg(
coatnet_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1459,8 +1504,8 @@ model_cfgs = dict(
),
),
# Highly experimental configs
coatnet_bn_0_rw_224=MaxxVitCfg(
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
coatnet_bn_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1471,7 +1516,7 @@ model_cfgs = dict(
transformer_norm_layer='batchnorm2d',
)
),
coatnet_rmlp_nano_rw_224=MaxxVitCfg(
coatnet_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1482,7 +1527,7 @@ model_cfgs = dict(
rel_pos_dim=384,
),
),
coatnet_rmlp_0_rw_224=MaxxVitCfg(
coatnet_rmlp_0_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 7, 2), # deeper than paper '0' model
stem_width=(32, 64),
@ -1491,7 +1536,7 @@ model_cfgs = dict(
rel_pos_type='mlp',
),
),
coatnet_rmlp_1_rw_224=MaxxVitCfg(
coatnet_rmlp_1_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1503,7 +1548,7 @@ model_cfgs = dict(
rel_pos_dim=384, # was supposed to be 512, woops
),
),
coatnet_rmlp_1_rw2_224=MaxxVitCfg(
coatnet_rmlp_1_rw2=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=(32, 64),
@ -1513,7 +1558,7 @@ model_cfgs = dict(
rel_pos_dim=512, # was supposed to be 512, woops
),
),
coatnet_rmlp_2_rw_224=MaxxVitCfg(
coatnet_rmlp_2_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=(64, 128),
@ -1524,7 +1569,7 @@ model_cfgs = dict(
rel_pos_type='mlp'
),
),
coatnet_rmlp_3_rw_224=MaxxVitCfg(
coatnet_rmlp_3_rw=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=(96, 192),
@ -1536,14 +1581,14 @@ model_cfgs = dict(
),
),
coatnet_nano_cc_224=MaxxVitCfg(
coatnet_nano_cc=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
block_type=('C', 'C', ('C', 'T'), ('C', 'T')),
**_rw_coat_cfg(),
),
coatnext_nano_rw_224=MaxxVitCfg(
coatnext_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(3, 4, 6, 3),
stem_width=(32, 64),
@ -1555,89 +1600,95 @@ model_cfgs = dict(
),
# Trying to be like the CoAtNet paper configs
coatnet_0_224=MaxxVitCfg(
coatnet_0=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 3, 5, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_1_224=MaxxVitCfg(
coatnet_1=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
stem_width=64,
head_hidden_size=768,
),
coatnet_2_224=MaxxVitCfg(
coatnet_2=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 14, 2),
stem_width=128,
head_hidden_size=1024,
),
coatnet_3_224=MaxxVitCfg(
coatnet_3=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 6, 14, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_4_224=MaxxVitCfg(
coatnet_4=MaxxVitCfg(
embed_dim=(192, 384, 768, 1536),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=1536,
),
coatnet_5_224=MaxxVitCfg(
coatnet_5=MaxxVitCfg(
embed_dim=(256, 512, 1280, 2048),
depths=(2, 12, 28, 2),
stem_width=192,
head_hidden_size=2048,
),
# Experimental MaxVit configs
maxvit_pico_rw_256=MaxxVitCfg(
maxvit_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(),
),
maxvit_nano_rw_256=MaxxVitCfg(
maxvit_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_224=MaxxVitCfg(
maxvit_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_tiny_rw_256=MaxxVitCfg(
maxvit_tiny_pm=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxvit_rmlp_pico_rw_256=MaxxVitCfg(
maxvit_rmlp_pico_rw=MaxxVitCfg(
embed_dim=(32, 64, 128, 256),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(24, 32),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(rel_pos_type='mlp'),
),
maxvit_rmlp_small_rw_224=MaxxVitCfg(
maxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
@ -1647,26 +1698,18 @@ model_cfgs = dict(
init_values=1e-6,
),
),
maxvit_rmlp_small_rw_256=MaxxVitCfg(
maxvit_rmlp_base_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
init_values=1e-6,
),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('PM',) * 4,
stem_width=(32, 64),
**_rw_max_cfg(),
),
maxxvit_rmlp_nano_rw_256=MaxxVitCfg(
maxxvit_rmlp_nano_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
@ -1674,33 +1717,50 @@ model_cfgs = dict(
weight_init='normal',
**_next_cfg(),
),
maxxvit_rmlp_tiny_rw_256=MaxxVitCfg(
maxxvit_rmlp_tiny_rw=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
**_next_cfg(),
),
maxxvit_rmlp_small_rw_256=MaxxVitCfg(
maxxvit_rmlp_small_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 2, 5, 2),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
),
maxxvit_rmlp_base_rw_224=MaxxVitCfg(
maxxvitv2_nano_rw=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
depths=(1, 2, 3, 1),
block_type=('M',) * 4,
stem_width=(48, 96),
**_next_cfg(),
weight_init='normal',
**_next_cfg(
no_block_attn=True,
rel_pos_type='bias',
),
),
maxxvit_rmlp_large_rw_224=MaxxVitCfg(
maxxvitv2_rmlp_base_rw=MaxxVitCfg(
embed_dim=(128, 256, 512, 1024),
depths=(2, 6, 12, 2),
block_type=('M',) * 4,
stem_width=(64, 128),
**_next_cfg(),
**_next_cfg(
no_block_attn=True,
),
),
maxxvitv2_rmlp_large_rw=MaxxVitCfg(
embed_dim=(160, 320, 640, 1280),
depths=(2, 6, 16, 2),
block_type=('M',) * 4,
stem_width=(80, 160),
head_hidden_size=1280,
**_next_cfg(
no_block_attn=True,
),
),
# Trying to be like the MaxViT paper configs
@ -1752,11 +1812,29 @@ model_cfgs = dict(
)
def checkpoint_filter_fn(state_dict, model: nn.Module):
model_state_dict = model.state_dict()
out_dict = {}
for k, v in state_dict.items():
if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel():
# adapt between conv2d / linear layers
assert v.ndim in (2, 4)
v = v.reshape(model_state_dict[k].shape)
out_dict[k] = v
return out_dict
def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs):
if cfg_variant is None:
if variant in model_cfgs:
cfg_variant = variant
else:
cfg_variant = '_'.join(variant.split('_')[:-1])
return build_model_with_cfg(
MaxxVit, variant, pretrained,
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
model_cfg=model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
@ -1772,149 +1850,218 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# Fiddling with configs / defaults / still pretraining
'coatnet_pico_rw_224': _cfg(url=''),
'coatnet_nano_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos
'coatnet_pico_rw_224.untrained': _cfg(url=''),
'coatnet_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth',
crop_pct=0.9),
'coatnet_0_rw_224': _cfg(
'coatnet_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'),
'coatnet_1_rw_224': _cfg(
'coatnet_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth'
),
'coatnet_2_rw_224': _cfg(url=''),
'coatnet_3_rw_224': _cfg(url=''),
# Highly experimental configs
'coatnet_bn_0_rw_224': _cfg(
# timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos
'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
#'coatnet_3_rw_224.untrained': _cfg(url=''),
# Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos)
'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos)
'coatnet_bn_0_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
crop_pct=0.95),
'coatnet_rmlp_nano_rw_224': _cfg(
'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth',
crop_pct=0.9),
'coatnet_rmlp_0_rw_224': _cfg(url=''),
'coatnet_rmlp_1_rw_224': _cfg(
'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''),
'coatnet_rmlp_1_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'),
'coatnet_rmlp_1_rw2_224': _cfg(url=''),
'coatnet_rmlp_2_rw_224': _cfg(
'coatnet_rmlp_2_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'),
'coatnet_rmlp_3_rw_224': _cfg(url=''),
'coatnet_nano_cc_224': _cfg(url=''),
'coatnext_nano_rw_224': _cfg(
'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''),
'coatnet_nano_cc_224.untrained': _cfg(url=''),
'coatnext_nano_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth',
crop_pct=0.9),
# Trying to be like the CoAtNet paper configs
'coatnet_0_224': _cfg(url=''),
'coatnet_1_224': _cfg(url=''),
'coatnet_2_224': _cfg(url=''),
'coatnet_3_224': _cfg(url=''),
'coatnet_4_224': _cfg(url=''),
'coatnet_5_224': _cfg(url=''),
# Experimental configs
'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256': _cfg(
# ImagenNet-12k pretrain CoAtNet
'coatnet_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_3_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
'coatnet_rmlp_2_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released)
'coatnet_0_224.untrained': _cfg(url=''),
'coatnet_1_224.untrained': _cfg(url=''),
'coatnet_2_224.untrained': _cfg(url=''),
'coatnet_3_224.untrained': _cfg(url=''),
'coatnet_4_224.untrained': _cfg(url=''),
'coatnet_5_224.untrained': _cfg(url=''),
# timm specific MaxVit configs, ImageNet-1k pretrain or untrained
'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_rw_224': _cfg(
'maxvit_tiny_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'),
'maxvit_tiny_rw_256': _cfg(
'maxvit_tiny_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_pico_rw_256': _cfg(
'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain
'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_nano_rw_256': _cfg(
'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_tiny_rw_256': _cfg(
'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_small_rw_224': _cfg(
'maxvit_rmlp_small_rw_224.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth',
crop_pct=0.9,
),
'maxvit_rmlp_small_rw_256': _cfg(
'maxvit_rmlp_small_rw_256.untrained': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
# timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune
'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
),
'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
# timm specific MaxVit w/ ImageNet-12k pretrain
'maxvit_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821,
),
'maxxvit_rmlp_nano_rw_256': _cfg(
# timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks)
'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256': _cfg(
'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvit_rmlp_base_rw_224': _cfg(url=''),
'maxxvit_rmlp_large_rw_224': _cfg(url=''),
# timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn)
'maxxvitv2_nano_rw_256.sw_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''),
'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg(
hf_hub_id='timm/',
num_classes=11821),
# MaxViT models ported from official Tensorflow impl
'maxvit_tiny_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_tiny_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_tiny_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_tiny_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_small_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_small_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_small_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_base_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_224.in1k',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'maxvit_large_tf_384.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_224.in21k': _cfg(
url=''),
'maxvit_base_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_base_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_224.in21k': _cfg(
url=''),
'maxvit_large_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_large_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k',
hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_224.in21k': _cfg(
url=''),
'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k',
input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg(
hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k',
input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'),
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'),
})
@ -1978,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs)
@register_model
def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs)
@ -2068,6 +2220,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)
@ -2089,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
@register_model
def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_nano_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs)
@register_model

@ -266,9 +266,16 @@ class MobileVitBlock(nn.Module):
self.transformer = nn.Sequential(*[
TransformerBlock(
transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
act_layer=layers.act, norm_layer=transformer_norm_layer)
transformer_dim,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
qkv_bias=True,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer,
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)

@ -17,7 +17,7 @@ Status:
Hacked together by / copyright Ross Wightman, 2021.
"""
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Tuple, Optional
@ -159,11 +159,25 @@ class NfCfg:
def _nfres_cfg(
depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
depths,
channels=(256, 512, 1024, 2048),
group_size=None,
act_layer='relu',
attn_layer=None,
attn_kwargs=None,
):
attn_kwargs = attn_kwargs or {}
cfg = NfCfg(
depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='7x7_pool',
stem_chs=64,
bottle_ratio=0.25,
group_size=group_size,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='3x3',
group_size=8,
width_factor=0.75,
bottle_ratio=2.25,
num_features=num_features,
reg=True,
attn_layer='se',
attn_kwargs=attn_kwargs,
)
return cfg
def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
depths,
channels=(256, 512, 1536, 1536),
group_size=128,
bottle_ratio=0.5,
feat_mult=2.,
act_layer='gelu',
attn_layer='se',
attn_kwargs=None,
):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=group_size,
bottle_ratio=bottle_ratio,
extra_conv=True,
num_features=num_features,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
def _dm_nfnet_cfg(
depths,
channels=(256, 512, 1536, 1536),
act_layer='gelu',
skipinit=True,
):
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=128,
bottle_ratio=0.5,
extra_conv=True,
gamma_in_act=True,
same_padding=True,
skipinit=skipinit,
num_features=int(channels[-1] * 2.0),
act_layer=act_layer,
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.5),
)
return cfg
@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
conv_layer=ScaledStdConv2d,
):
""" AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
stride=1,
dilation=1,
first_dilation=None,
alpha=1.0,
beta=1.0,
bottle_ratio=0.25,
group_size=None,
ch_div=1,
reg=True,
extra_conv=False,
skipinit=False,
attn_layer=None,
attn_gain=2.0,
act_layer=None,
conv_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module):
if in_chs != out_chs or stride != 1 or dilation != first_dilation:
self.downsample = DownsampleAvg(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
in_chs,
out_chs,
stride=stride,
dilation=dilation,
first_dilation=first_dilation,
conv_layer=conv_layer,
)
else:
self.downsample = None
@ -452,14 +538,33 @@ class NormFreeNet(nn.Module):
for what it is/does. Approx 8-10% throughput loss.
"""
def __init__(
self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.
self,
cfg: NfCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (NfCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs)
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
@ -472,7 +577,12 @@ class NormFreeNet(nn.Module):
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
in_chans,
stem_chs,
cfg.stem_type,
conv_layer=conv_layer,
act_layer=act_layer,
)
self.feature_info = [stem_feat]
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]

@ -14,7 +14,7 @@ Weights from original impl have been modified
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Optional, Union, Callable
@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
def create_shortcut(
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
downsample_type,
in_chs,
out_chs,
kernel_size,
stride,
dilation=(1, 1),
norm_layer=None,
preact=False,
):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
@ -259,9 +267,21 @@ class Bottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(Bottleneck, self).__init__()
act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -307,9 +327,21 @@ class PreBottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(PreBottleneck, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -353,8 +385,16 @@ class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(
self, depth, in_chs, out_chs, stride, dilation,
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
self,
depth,
in_chs,
out_chs,
stride,
dilation,
drop_path_rates=None,
block_fn=Bottleneck,
**block_kwargs,
):
super(RegStage, self).__init__()
self.grad_checkpointing = False
@ -367,8 +407,13 @@ class RegStage(nn.Module):
name = "b{}".format(i + 1)
self.add_module(
name, block_fn(
block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
drop_path_rate=dpr, **block_kwargs)
block_in_chs,
out_chs,
stride=block_stride,
dilation=block_dilation,
drop_path_rate=dpr,
**block_kwargs,
)
)
first_dilation = dilation
@ -389,12 +434,35 @@ class RegNet(nn.Module):
"""
def __init__(
self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
drop_rate=0., drop_path_rate=0., zero_init_last=True):
self,
cfg: RegNetCfg,
in_chans=3,
num_classes=1000,
output_stride=32,
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (RegNetCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs
# Construct the stem
stem_width = cfg.stem_width
@ -428,7 +496,7 @@ class RegNet(nn.Module):
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
@ -461,8 +529,12 @@ class RegNet(nn.Module):
dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)]
common_args = dict(
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
downsample=cfg.downsample,
se_ratio=cfg.se_ratio,
linear_out=cfg.linear_out,
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
)
return per_stage_args, common_args
@torch.jit.ignore
@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'classy_state_dict' in state_dict:
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']

@ -156,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -167,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -178,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -189,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -200,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -211,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -222,5 +222,5 @@ def res2next50(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2next50', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))

@ -163,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[1, 1, 1, 1],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -174,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[2, 2, 2, 2],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -186,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -198,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 23, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -210,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 24, 36, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -222,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 30, 48, 8],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -233,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=4, avd=True, avd_first=True))
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -244,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=1, avd=True, avd_first=True))
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

File diff suppressed because it is too large Load Diff

@ -37,7 +37,7 @@ import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
from ._registry import register_model
@ -276,8 +276,16 @@ class Bottleneck(nn.Module):
class DownsampleConv(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
super(DownsampleConv, self).__init__()
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
@ -288,8 +296,16 @@ class DownsampleConv(nn.Module):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
preact=True, conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -334,9 +350,18 @@ class ResNetStage(nn.Module):
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
stride = stride if block_idx == 0 else 1
self.blocks.add_module(str(block_idx), block_fn(
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
**layer_kwargs, **block_kwargs))
prev_chs,
out_chs,
stride=stride,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
first_dilation=first_dilation,
proj_layer=proj_layer,
drop_path_rate=drop_path_rate,
**layer_kwargs,
**block_kwargs,
))
prev_chs = out_chs
first_dilation = dilation
proj_layer = None
@ -413,21 +438,49 @@ class ResNetV2(nn.Module):
avg_down=False,
preact=True,
act_layer=nn.ReLU,
conv_layer=StdConv2d,
norm_layer=partial(GroupNormAct, num_groups=32),
conv_layer=StdConv2d,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=False,
):
"""
Args:
layers (List[int]) : number of layers in each block
channels (List[int]) : number of channels in each block:
num_classes (int): number of classification classes (default 1000)
in_chans (int): number of input (color) channels. (default 3)
global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
width_factor (int): channel (width) multiplication factor
stem_chs (int): stem width (default: 64)
stem_type (str): stem type (default: '' == 7x7)
avg_down (bool): average pooling in residual downsampling (default: False)
preact (bool): pre-activiation (default: True)
act_layer (Union[str, nn.Module]): activation layer
norm_layer (Union[str, nn.Module]): normalization layer
conv_layer (nn.Module): convolution module
drop_rate: classifier dropout rate (default: 0.)
drop_path_rate: stochastic depth rate (default: 0.)
zero_init_last: zero-init last weight in residual path (default: False)
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
wf = width_factor
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
act_layer = get_act_layer(act_layer)
self.feature_info = []
stem_chs = make_div(stem_chs * wf)
self.stem = create_resnetv2_stem(
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
in_chans,
stem_chs,
stem_type,
preact,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
@ -693,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
@register_model
def resnetv2_50(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50t', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='tiered', avg_down=True, **kwargs)
stem_type='tiered', avg_down=True)
return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101', pretrained=pretrained,
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152', pretrained=pretrained,
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152d', pretrained=pretrained,
model_args = dict(
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))
# Experimental configs (may change / be removed)
@register_model
def resnetv2_50d_gn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_gn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
stem_type='deep', avg_down=True, zero_init_last=True)
return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_frn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))

@ -697,6 +697,13 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
# re-finetuned augreg 21k FT on in1k weights
'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
hf_hub_id='timm/'),
# How to train your ViT (augreg) weights, pretrained on 21k FT on in1k
'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
@ -751,13 +758,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
custom_load=True, input_size=(3, 384, 384), crop_pct=1.0),
# re-finetuned augreg 21k FT on in1k weights
'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg(
hf_hub_id='timm/'),
'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(),
'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg(
hf_hub_id='timm/'),
# patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k
'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth',
@ -802,7 +802,6 @@ default_cfgs = generate_default_cfgs({
'vit_giant_patch14_224.untrained': _cfg(url=''),
'vit_gigantic_patch14_224.untrained': _cfg(url=''),
# patch models, imagenet21k (weights from official Google JAX impl)
'vit_large_patch32_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth',
@ -869,7 +868,6 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
@ -880,7 +878,7 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'),
# custom timm variants
# Custom timm variants
'vit_base_patch16_rpn_224.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth',
hf_hub_id='timm/'),
@ -896,52 +894,6 @@ default_cfgs = generate_default_cfgs({
'vit_base_patch16_gap_224': _cfg(),
# CLIP pretrained image tower and related fine-tuned weights
'vit_base_patch32_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.laion2b': _cfg(
#hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_giant_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
@ -973,28 +925,52 @@ default_cfgs = generate_default_cfgs({
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg(
#hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
# hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_base_patch32_clip_224.openai': _cfg(
'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.openai': _cfg(
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_clip_224.openai': _cfg(
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg(
hf_hub_id='',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
'vit_base_patch32_clip_224.openai_ft_in1k': _cfg(
hf_hub_id='timm/',
@ -1010,30 +986,21 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg(
#hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD),
'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95),
'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg(
'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg(
#hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'),
'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg(
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0),
'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg(
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821),
'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'),
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
'vit_base_patch32_clip_224.openai_ft_in12k': _cfg(
#hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
# hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821),
'vit_base_patch16_clip_224.openai_ft_in12k': _cfg(
hf_hub_id='timm/',
@ -1042,6 +1009,41 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821),
'vit_base_patch32_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.laion2b': _cfg(
# hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512),
'vit_large_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768),
'vit_huge_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_giant_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024),
'vit_gigantic_patch14_clip_224.laion2b': _cfg(
hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280),
'vit_base_patch32_clip_224.openai': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_base_patch16_clip_224.openai': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_large_patch14_clip_224.openai': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
@ -1152,8 +1154,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
def vit_tiny_patch16_224(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1161,8 +1163,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs):
def vit_tiny_patch16_384(pretrained=False, **kwargs):
""" ViT-Tiny (Vit-Ti/16) @ 384x384.
"""
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3)
model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1170,8 +1172,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs):
def vit_small_patch32_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32)
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1179,8 +1181,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs):
def vit_small_patch32_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/32) at 384x384.
"""
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1188,8 +1190,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs):
def vit_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1197,8 +1199,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs):
def vit_small_patch16_384(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/16)
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1206,8 +1208,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs):
def vit_small_patch8_224(pretrained=False, **kwargs):
""" ViT-Small (ViT-S/8)
"""
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6)
model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1216,8 +1218,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1226,8 +1228,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1236,8 +1238,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1246,8 +1248,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1256,8 +1258,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12)
model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1265,8 +1267,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs):
def vit_large_patch32_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1275,8 +1277,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1285,8 +1287,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1295,8 +1297,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929).
ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer.
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1304,8 +1306,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs):
def vit_large_patch14_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14)
"""
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16)
model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1313,8 +1315,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs):
def vit_huge_patch14_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929).
"""
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16)
model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1322,8 +1324,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs):
def vit_giant_patch14_224(pretrained=False, **kwargs):
""" ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16)
model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1331,8 +1333,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs):
def vit_gigantic_patch14_224(pretrained=False, **kwargs):
""" ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
"""
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
model = _create_vision_transformer(
'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1341,8 +1344,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929).
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False)
model = _create_vision_transformer(
'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1352,8 +1356,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1363,8 +1368,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1374,8 +1380,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs):
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs)
global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False)
model = _create_vision_transformer(
'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1384,9 +1391,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False,
global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs)
model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
model = _create_vision_transformer(
'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1395,8 +1402,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 224x224
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1405,8 +1413,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 384x384
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1415,8 +1424,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
""" ViT-B/32 CLIP image tower @ 448x448
"""
model_kwargs = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1424,9 +1434,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs):
def vit_base_patch16_clip_224(pretrained=False, **kwargs):
""" ViT-B/16 CLIP image tower
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1434,9 +1444,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs):
def vit_base_patch16_clip_384(pretrained=False, **kwargs):
""" ViT-B/16 CLIP image tower @ 384x384
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1444,9 +1454,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs):
def vit_large_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) CLIP image tower
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1454,9 +1464,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs):
def vit_large_patch14_clip_336(pretrained=False, **kwargs):
""" ViT-Large model (ViT-L/14) CLIP image tower @ 336x336
"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1464,9 +1474,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs):
def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1474,9 +1484,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs):
def vit_huge_patch14_clip_336(pretrained=False, **kwargs):
""" ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336
"""
model_kwargs = dict(
patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1486,20 +1496,32 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs):
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16,
pre_norm=True, norm_layer=nn.LayerNorm, **kwargs)
model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs)
patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs):
""" ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560
Pretrained weights from CLIP image tower.
"""
model_kwargs = dict(
patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
# Experimental models below
@register_model
def vit_base_patch32_plus_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+)
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1507,8 +1529,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs):
def vit_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+)
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5)
model = _create_vision_transformer(
'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1517,9 +1540,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ residual post-norm
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5,
class_token=False, block_fn=ResPostBlock, global_pool='avg')
model = _create_vision_transformer(
'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1529,8 +1553,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5)
model = _create_vision_transformer(
'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1541,8 +1566,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs):
Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow.
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock)
model = _create_vision_transformer(
'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1551,27 +1577,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove.
Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs)
model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock)
model = _create_vision_transformer(
'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva_large_patch14_196(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
model = _create_vision_transformer(
'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@register_model
def eva_large_patch14_336(pretrained=False, **kwargs):
""" EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain"""
model_kwargs = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs)
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg')
model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1579,8 +1604,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs):
def flexivit_small(pretrained=False, **kwargs):
""" FlexiViT-Small
"""
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True)
model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1588,8 +1613,8 @@ def flexivit_small(pretrained=False, **kwargs):
def flexivit_base(pretrained=False, **kwargs):
""" FlexiViT-Base
"""
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model
@ -1597,6 +1622,6 @@ def flexivit_base(pretrained=False, **kwargs):
def flexivit_large(pretrained=False, **kwargs):
""" FlexiViT-Large
"""
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs)
model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True)
model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs))
return model

@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential):
class OsaBlock(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
self,
in_chs,
mid_chs,
out_chs,
layer_per_block,
residual=False,
depthwise=False,
attn='',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path=None,
):
super(OsaBlock, self).__init__()
self.residual = residual
@ -232,9 +242,20 @@ class OsaBlock(nn.Module):
class OsaStage(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
drop_path_rates=None):
self,
in_chs,
mid_chs,
out_chs,
block_per_stage,
layer_per_block,
downsample=True,
residual=True,
depthwise=False,
attn='ese',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path_rates=None,
):
super(OsaStage, self).__init__()
self.grad_checkpointing = False
@ -270,16 +291,38 @@ class OsaStage(nn.Module):
class VovNet(nn.Module):
def __init__(
self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
""" VovNet (v2)
self,
cfg,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (dict): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
norm_layer (Union[str, nn.Module]): normalization layer
act_layer (Union[str, nn.Module]): activation layer
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super(VovNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert stem_stride in (4, 2)
assert output_stride == 32 # FIXME support dilation
cfg = dict(cfg, **kwargs)
stem_stride = cfg.get("stem_stride", 4)
stem_chs = cfg["stem_chs"]
stage_conv_chs = cfg["stage_conv_chs"]
stage_out_chs = cfg["stage_out_chs"]
@ -307,9 +350,15 @@ class VovNet(nn.Module):
for i in range(4): # num_stages
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
stages += [OsaStage(
in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
]
in_ch_list[i],
stage_conv_chs[i],
stage_out_chs[i],
block_per_stage[i],
layer_per_block,
downsample=downsample,
drop_path_rates=stage_dpr[i],
**stage_args,
)]
self.num_features = stage_out_chs[i]
current_stride *= 2 if downsample else 1
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
@ -324,7 +373,6 @@ class VovNet(nn.Module):
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(

@ -216,7 +216,7 @@ class XceptionAligned(nn.Module):
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@torch.jit.ignore
def group_matcher(self, coarse=False):

@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed

@ -2,6 +2,8 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import argparse
import ast
import re
@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''):
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
parser.set_defaults(**{dest_name: default})
class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
for value in values:
key, value = value.split('=')
try:
kw[key] = ast.literal_eval(value)
except ValueError:
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
setattr(namespace, self.dest, kw)

@ -7,6 +7,8 @@ import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .model_ema import ModelEma
@ -100,70 +102,6 @@ def extract_spp_stats(
return hook.stats
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
"""
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
"""
assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(root_module, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
# Raise assertion here because we can't convert it in place
raise AssertionError(
"You have provided a batch norm layer as the `root module`. Please use "
@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
# It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
# convert it in place, but will return the converted result. In this case `res` holds the converted
# result and we may try to re-assign the named module
if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(m, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
_add_submodule(root_module, n, res)
# Unfreeze batch norm
else:
res = unfreeze_batch_norm_2d(m)
# Ditto. See note above in mode == 'freeze' branch
if isinstance(m, FrozenBatchNorm2d):
if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
_add_submodule(root_module, n, res)

@ -1 +1 @@
__version__ = '0.8.4dev0'
__version__ = '0.8.8dev0'

@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR',
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
help='path to class to idx mapping file (default: "")')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50")')
help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
metavar='N N N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='Validation batch size override (default: None)')
help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "step"')
help='LR scheduler (default: "step"')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)')
help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
help='save images of input bathes every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
help='log training and validation metrics to wandb')
def _parse_args():
@ -371,8 +374,6 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args)
if args.distributed:
@ -383,14 +384,6 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
@ -432,6 +425,7 @@ def main():
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -504,7 +498,11 @@ def main():
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
@ -559,6 +557,8 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
@ -712,6 +712,14 @@ def main():
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(

@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry
decay_batch_step, check_batch_size_retry, ParseKwargs
try:
from apex import amp
@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -181,13 +185,20 @@ def validate(args):
set_fast_norm()
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
global_pool=args.gp,
scriptable=args.torchscript,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -232,8 +243,9 @@ def validate(args):
criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
dataset = create_dataset(
root=args.data,
root=root_dir,
name=args.dataset,
split=args.split,
download=args.dataset_download,
@ -389,7 +401,7 @@ def main():
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter

Loading…
Cancel
Save