diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 5690c88c..70352d0a 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,9 +40,10 @@ jobs: - name: Install torch on ubuntu if: startsWith(matrix.os, 'ubuntu') run: | - pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html + sudo sed -i 's/azure\.//' /etc/apt/sources.list sudo apt update sudo apt install -y google-perftools + pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html - name: Install requirements run: | pip install -r requirements.txt diff --git a/README.md b/README.md index 459b70c1..287b6f66 100644 --- a/README.md +++ b/README.md @@ -21,12 +21,73 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before ## What's New -### 🤗 Survey: Feedback Appreciated 🤗 - -For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing. - -If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts: -[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏 +* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗ +* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch. + +### Jan 20, 2023 +* Add two convnext 12k -> 1k fine-tunes at 384x384 + * `convnext_tiny.in12k_ft_in1k_384` - 85.1 @ 384 + * `convnext_small.in12k_ft_in1k_384` - 86.2 @ 384 + +* Push all MaxxViT weights to HF hub, and add new ImageNet-12k -> 1k fine-tunes for `rw` base MaxViT and CoAtNet 1/2 models + +|model |top1 |top5 |samples / sec |Params (M) |GMAC |Act (M)| +|------------------------------------------------------------------------------------------------------------------------|----:|----:|--------------:|--------------:|-----:|------:| +|[maxvit_xlarge_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |88.53|98.64| 21.76| 475.77|534.14|1413.22| +|[maxvit_xlarge_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |88.32|98.54| 42.53| 475.32|292.78| 668.76| +|[maxvit_base_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |88.20|98.53| 50.87| 119.88|138.02| 703.99| +|[maxvit_large_tf_512.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |88.04|98.40| 36.42| 212.33|244.75| 942.15| +|[maxvit_large_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |87.98|98.56| 71.75| 212.03|132.55| 445.84| +|[maxvit_base_tf_384.in21k_ft_in1k](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |87.92|98.54| 104.71| 119.65| 73.80| 332.90| +|[maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.81|98.37| 106.55| 116.14| 70.97| 318.95| +|[maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k) |87.47|98.37| 149.49| 116.09| 72.98| 213.74| +|[coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k) |87.39|98.31| 160.80| 73.88| 47.69| 209.43| +|[maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.89|98.02| 375.86| 116.14| 23.15| 92.64| +|[maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k) |86.64|98.02| 501.03| 116.09| 24.20| 62.77| +|[maxvit_base_tf_512.in1k](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |86.60|97.92| 50.75| 119.88|138.02| 703.99| +|[coatnet_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_2_rw_224.sw_in12k_ft_in1k) |86.57|97.89| 631.88| 73.87| 15.09| 49.22| +|[maxvit_large_tf_512.in1k](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |86.52|97.88| 36.04| 212.33|244.75| 942.15| +|[coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k) |86.49|97.90| 620.58| 73.88| 15.18| 54.78| +|[maxvit_base_tf_384.in1k](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |86.29|97.80| 101.09| 119.65| 73.80| 332.90| +|[maxvit_large_tf_384.in1k](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |86.23|97.69| 70.56| 212.03|132.55| 445.84| +|[maxvit_small_tf_512.in1k](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |86.10|97.76| 88.63| 69.13| 67.26| 383.77| +|[maxvit_tiny_tf_512.in1k](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |85.67|97.58| 144.25| 31.05| 33.49| 257.59| +|[maxvit_small_tf_384.in1k](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |85.54|97.46| 188.35| 69.02| 35.87| 183.65| +|[maxvit_tiny_tf_384.in1k](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |85.11|97.38| 293.46| 30.98| 17.53| 123.42| +|[maxvit_large_tf_224.in1k](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |84.93|96.97| 247.71| 211.79| 43.68| 127.35| +|[coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k) |84.90|96.96| 1025.45| 41.72| 8.11| 40.13| +|[maxvit_base_tf_224.in1k](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |84.85|96.99| 358.25| 119.47| 24.04| 95.01| +|[maxxvit_rmlp_small_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_small_rw_256.sw_in1k) |84.63|97.06| 575.53| 66.01| 14.67| 58.38| +|[coatnet_rmlp_2_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_2_rw_224.sw_in1k) |84.61|96.74| 625.81| 73.88| 15.18| 54.78| +|[maxvit_rmlp_small_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_small_rw_224.sw_in1k) |84.49|96.76| 693.82| 64.90| 10.75| 49.30| +|[maxvit_small_tf_224.in1k](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |84.43|96.83| 647.96| 68.93| 11.66| 53.17| +|[maxvit_rmlp_tiny_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_tiny_rw_256.sw_in1k) |84.23|96.78| 807.21| 29.15| 6.77| 46.92| +|[coatnet_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_1_rw_224.sw_in1k) |83.62|96.38| 989.59| 41.72| 8.04| 34.60| +|[maxvit_tiny_rw_224.sw_in1k](https://huggingface.co/timm/maxvit_tiny_rw_224.sw_in1k) |83.50|96.50| 1100.53| 29.06| 5.11| 33.11| +|[maxvit_tiny_tf_224.in1k](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |83.41|96.59| 1004.94| 30.92| 5.60| 35.78| +|[coatnet_rmlp_1_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_1_rw_224.sw_in1k) |83.36|96.45| 1093.03| 41.69| 7.85| 35.47| +|[maxxvitv2_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvitv2_nano_rw_256.sw_in1k) |83.11|96.33| 1276.88| 23.70| 6.26| 23.05| +|[maxxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxxvit_rmlp_nano_rw_256.sw_in1k) |83.03|96.34| 1341.24| 16.78| 4.37| 26.05| +|[maxvit_rmlp_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_nano_rw_256.sw_in1k) |82.96|96.26| 1283.24| 15.50| 4.47| 31.92| +|[maxvit_nano_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_nano_rw_256.sw_in1k) |82.93|96.23| 1218.17| 15.45| 4.46| 30.28| +|[coatnet_bn_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_bn_0_rw_224.sw_in1k) |82.39|96.19| 1600.14| 27.44| 4.67| 22.04| +|[coatnet_0_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_0_rw_224.sw_in1k) |82.39|95.84| 1831.21| 27.44| 4.43| 18.73| +|[coatnet_rmlp_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_rmlp_nano_rw_224.sw_in1k) |82.05|95.87| 2109.09| 15.15| 2.62| 20.34| +|[coatnext_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnext_nano_rw_224.sw_in1k) |81.95|95.92| 2525.52| 14.70| 2.47| 12.80| +|[coatnet_nano_rw_224.sw_in1k](https://huggingface.co/timm/coatnet_nano_rw_224.sw_in1k) |81.70|95.64| 2344.52| 15.14| 2.41| 15.41| +|[maxvit_rmlp_pico_rw_256.sw_in1k](https://huggingface.co/timm/maxvit_rmlp_pico_rw_256.sw_in1k) |80.53|95.21| 1594.71| 7.52| 1.85| 24.86| + +### Jan 11, 2023 +* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags) + * `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released) + * `convnext_tiny.in12k_ft_in1k` - 84.2 @ 224, 84.5 @ 288 + * `convnext_small.in12k_ft_in1k` - 85.2 @ 224, 85.3 @ 288 + +### Jan 6, 2023 +* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line + * `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu` + * `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12` +* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go. ### Jan 5, 2023 * ConvNeXt-V2 models and weights added to existing `convnext.py` diff --git a/benchmark.py b/benchmark.py index 58435ff8..2cce3e2c 100755 --- a/benchmark.py +++ b/benchmark.py @@ -22,7 +22,7 @@ from timm.data import resolve_data_config from timm.layers import set_fast_norm from timm.models import create_model, is_model, list_models from timm.optim import create_optimizer_v2 -from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry +from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs has_apex = False try: @@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') parser.add_argument('--amp', action='store_true', default=False, help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.') +parser.add_argument('--amp-dtype', default='float16', type=str, + help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.') parser.add_argument('--precision', default='float32', type=str, help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) # codegen (model compilation) options scripting_group = parser.add_mutually_exclusive_group() @@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None scripting_group.add_argument('--aot-autograd', default=False, action='store_true', help="Enable AOT Autograd optimization.") - # train optimizer parameters parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') @@ -168,19 +170,21 @@ def count_params(model: nn.Module): def resolve_precision(precision: str): - assert precision in ('amp', 'float16', 'bfloat16', 'float32') - use_amp = False + assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32') + amp_dtype = None # amp disabled model_dtype = torch.float32 data_dtype = torch.float32 if precision == 'amp': - use_amp = True + amp_dtype = torch.float16 + elif precision == 'amp_bfloat16': + amp_dtype = torch.bfloat16 elif precision == 'float16': model_dtype = torch.float16 data_dtype = torch.float16 elif precision == 'bfloat16': model_dtype = torch.bfloat16 data_dtype = torch.bfloat16 - return use_amp, model_dtype, data_dtype + return amp_dtype, model_dtype, data_dtype def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False): @@ -228,9 +232,12 @@ class BenchmarkRunner: self.model_name = model_name self.detail = detail self.device = device - self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision) + self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision) self.channels_last = kwargs.pop('channels_last', False) - self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress + if self.amp_dtype is not None: + self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype) + else: + self.amp_autocast = suppress if fuser: set_jit_fuser(fuser) @@ -243,6 +250,7 @@ class BenchmarkRunner: drop_rate=kwargs.pop('drop', 0.), drop_path_rate=kwargs.pop('drop_path', None), drop_block_rate=kwargs.pop('drop_block', None), + **kwargs.pop('model_kwargs', {}), ) self.model.to( device=self.device, @@ -560,7 +568,7 @@ def _try_run( def benchmark(args): if args.amp: _logger.warning("Overriding precision to 'amp' since --amp flag set.") - args.precision = 'amp' + args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype]) _logger.info(f'Benchmarking in {args.precision} precision. ' f'{"NHWC" if args.channels_last else "NCHW"} layout. ' f'torchscript {"enabled" if args.torchscript else "disabled"}') diff --git a/inference.py b/inference.py index 1509b323..cfbe62d1 100755 --- a/inference.py +++ b/inference.py @@ -20,7 +20,7 @@ import torch from timm.data import create_dataset, create_loader, resolve_data_config from timm.layers import apply_test_time_pool from timm.models import create_model -from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser +from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs try: from apex import amp @@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--use-train-size', action='store_true', default=False, @@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str, help='lower precision AMP dtype (default: float16)') parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) scripting_group = parser.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', default=False, action='store_true', @@ -170,12 +173,19 @@ def main(): set_jit_fuser(args.fuser) # create model + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, num_classes=args.num_classes, - in_chans=3, + in_chans=in_chans, pretrained=args.pretrained, checkpoint_path=args.checkpoint, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' diff --git a/results/README.md b/results/README.md index 4fabf64b..81f30061 100644 --- a/results/README.md +++ b/results/README.md @@ -38,7 +38,7 @@ An ImageNet test set of 10,000 images sampled from new images roughly 10 years a ### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv) -A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occuring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1. +A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1. For clean validation with same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv) diff --git a/tests/test_models.py b/tests/test_models.py index 3e91d9a8..fdededc7 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -27,7 +27,7 @@ NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*', 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*', - 'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*' + 'eva_*', 'flexivit*' ] NUM_NON_STD = len(NON_STD_FILTERS) @@ -38,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ: '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*', '*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*', - 'swin*giant*', 'convnextv2_huge*'] + 'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge'] NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*'] else: EXCLUDE_FILTERS = [] @@ -53,7 +53,7 @@ MAX_JIT_SIZE = 320 TARGET_FFEAT_SIZE = 96 MAX_FFEAT_SIZE = 256 TARGET_FWD_FX_SIZE = 128 -MAX_FWD_FX_SIZE = 224 +MAX_FWD_FX_SIZE = 256 TARGET_BWD_FX_SIZE = 128 MAX_BWD_FX_SIZE = 224 @@ -269,7 +269,7 @@ if 'GITHUB_ACTIONS' not in os.environ: EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'dla*', 'hrnet*', 'ghostnet*' # hopefully fix at some point 'vit_large_*', 'vit_huge_*', 'vit_gi*', ] diff --git a/timm/data/__init__.py b/timm/data/__init__.py index 7cc7b0b0..9f62a7d5 100644 --- a/timm/data/__init__.py +++ b/timm/data/__init__.py @@ -1,6 +1,6 @@ from .auto_augment import RandAugment, AutoAugment, rand_augment_ops, auto_augment_policy,\ rand_augment_transform, auto_augment_transform -from .config import resolve_data_config +from .config import resolve_data_config, resolve_model_data_config from .constants import * from .dataset import ImageDataset, IterableImageDataset, AugMixDataset from .dataset_factory import create_dataset diff --git a/timm/data/config.py b/timm/data/config.py index a65695d0..a6c2298c 100644 --- a/timm/data/config.py +++ b/timm/data/config.py @@ -6,16 +6,18 @@ _logger = logging.getLogger(__name__) def resolve_data_config( - args, - default_cfg=None, + args=None, + pretrained_cfg=None, model=None, use_test_size=False, verbose=False ): - new_config = {} - default_cfg = default_cfg or {} - if not default_cfg and model is not None and hasattr(model, 'default_cfg'): - default_cfg = model.default_cfg + assert model or args or pretrained_cfg, "At least one of model, args, or pretrained_cfg required for data config." + args = args or {} + pretrained_cfg = pretrained_cfg or {} + if not pretrained_cfg and model is not None and hasattr(model, 'pretrained_cfg'): + pretrained_cfg = model.pretrained_cfg + data_config = {} # Resolve input/image size in_chans = 3 @@ -32,65 +34,94 @@ def resolve_data_config( assert isinstance(args['img_size'], int) input_size = (in_chans, args['img_size'], args['img_size']) else: - if use_test_size and default_cfg.get('test_input_size', None) is not None: - input_size = default_cfg['test_input_size'] - elif default_cfg.get('input_size', None) is not None: - input_size = default_cfg['input_size'] - new_config['input_size'] = input_size + if use_test_size and pretrained_cfg.get('test_input_size', None) is not None: + input_size = pretrained_cfg['test_input_size'] + elif pretrained_cfg.get('input_size', None) is not None: + input_size = pretrained_cfg['input_size'] + data_config['input_size'] = input_size # resolve interpolation method - new_config['interpolation'] = 'bicubic' + data_config['interpolation'] = 'bicubic' if args.get('interpolation', None): - new_config['interpolation'] = args['interpolation'] - elif default_cfg.get('interpolation', None): - new_config['interpolation'] = default_cfg['interpolation'] + data_config['interpolation'] = args['interpolation'] + elif pretrained_cfg.get('interpolation', None): + data_config['interpolation'] = pretrained_cfg['interpolation'] # resolve dataset + model mean for normalization - new_config['mean'] = IMAGENET_DEFAULT_MEAN + data_config['mean'] = IMAGENET_DEFAULT_MEAN if args.get('mean', None) is not None: mean = tuple(args['mean']) if len(mean) == 1: mean = tuple(list(mean) * in_chans) else: assert len(mean) == in_chans - new_config['mean'] = mean - elif default_cfg.get('mean', None): - new_config['mean'] = default_cfg['mean'] + data_config['mean'] = mean + elif pretrained_cfg.get('mean', None): + data_config['mean'] = pretrained_cfg['mean'] # resolve dataset + model std deviation for normalization - new_config['std'] = IMAGENET_DEFAULT_STD + data_config['std'] = IMAGENET_DEFAULT_STD if args.get('std', None) is not None: std = tuple(args['std']) if len(std) == 1: std = tuple(list(std) * in_chans) else: assert len(std) == in_chans - new_config['std'] = std - elif default_cfg.get('std', None): - new_config['std'] = default_cfg['std'] + data_config['std'] = std + elif pretrained_cfg.get('std', None): + data_config['std'] = pretrained_cfg['std'] # resolve default inference crop crop_pct = DEFAULT_CROP_PCT if args.get('crop_pct', None): crop_pct = args['crop_pct'] else: - if use_test_size and default_cfg.get('test_crop_pct', None): - crop_pct = default_cfg['test_crop_pct'] - elif default_cfg.get('crop_pct', None): - crop_pct = default_cfg['crop_pct'] - new_config['crop_pct'] = crop_pct + if use_test_size and pretrained_cfg.get('test_crop_pct', None): + crop_pct = pretrained_cfg['test_crop_pct'] + elif pretrained_cfg.get('crop_pct', None): + crop_pct = pretrained_cfg['crop_pct'] + data_config['crop_pct'] = crop_pct # resolve default crop percentage crop_mode = DEFAULT_CROP_MODE if args.get('crop_mode', None): crop_mode = args['crop_mode'] - elif default_cfg.get('crop_mode', None): - crop_mode = default_cfg['crop_mode'] - new_config['crop_mode'] = crop_mode + elif pretrained_cfg.get('crop_mode', None): + crop_mode = pretrained_cfg['crop_mode'] + data_config['crop_mode'] = crop_mode if verbose: _logger.info('Data processing configuration for current model + dataset:') - for n, v in new_config.items(): + for n, v in data_config.items(): _logger.info('\t%s: %s' % (n, str(v))) - return new_config + return data_config + + +def resolve_model_data_config( + model, + args=None, + pretrained_cfg=None, + use_test_size=False, + verbose=False, +): + """ Resolve Model Data Config + This is equivalent to resolve_data_config() but with arguments re-ordered to put model first. + + Args: + model (nn.Module): the model instance + args (dict): command line arguments / configuration in dict form (overrides pretrained_cfg) + pretrained_cfg (dict): pretrained model config (overrides pretrained_cfg attached to model) + use_test_size (bool): use the test time input resolution (if one exists) instead of default train resolution + verbose (bool): enable extra logging of resolved values + + Returns: + dictionary of config + """ + return resolve_data_config( + args=args, + pretrained_cfg=pretrained_cfg, + model=model, + use_test_size=use_test_size, + verbose=verbose, + ) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 625b4826..6b2dabba 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ + SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index 3ac33387..e885084c 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -38,13 +38,24 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False class ClassifierHead(nn.Module): """Classifier head w/ configurable global pooling and dropout.""" - def __init__(self, in_chs, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False): super(ClassifierHead, self).__init__() self.drop_rate = drop_rate - self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv) + self.in_features = in_features + self.use_conv = use_conv + + self.global_pool, num_pooled_features = _create_pool(in_features, num_classes, pool_type, use_conv=use_conv) self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv) self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity() + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + if global_pool != self.global_pool.pool_type: + self.global_pool, _ = _create_pool(self.in_features, num_classes, global_pool, use_conv=self.use_conv) + self.flatten = nn.Flatten(1) if self.use_conv and global_pool else nn.Identity() + num_pooled_features = self.in_features * self.global_pool.feat_mult() + self.fc = _create_fc(num_pooled_features, num_classes, use_conv=self.use_conv) + def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) if self.drop_rate: diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index ff075fbc..5ca21d18 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F +from torchvision.ops.misc import FrozenBatchNorm2d from .create_act import get_act_layer from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm @@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d): if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None if self.num_batches_tracked is not None: # type: ignore[has-type] - self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average @@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None): return module_output +class FrozenBatchNormAct2d(torch.nn.Module): + """ + BatchNormAct2d where the batch statistics and the affine parameters are fixed + + Args: + num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): + super().__init__() + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def _load_from_state_dict( + self, + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + x = x * scale + bias + x = self.act(self.drop(x)) + return x + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})" + + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers + of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)): + res = FrozenBatchNormAct2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + res.drop = module.drop + res.act = module.act + elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, FrozenBatchNormAct2d): + res = BatchNormAct2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + res.drop = module.drop + res.act = module.act + elif isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + def _num_groups(num_channels, num_groups, group_size): if group_size: assert num_channels % group_size == 0 @@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size): class GroupNormAct(nn.GroupNorm): # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args def __init__( - self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + num_channels, + num_groups=32, + eps=1e-5, + affine=True, + group_size=None, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(GroupNormAct, self).__init__( - _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) + _num_groups(num_channels, num_groups, group_size), + num_channels, + eps=eps, + affine=affine, + ) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + self._fast_norm = is_fast_norm() + + def forward(self, x): + if self._fast_norm: + x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class GroupNorm1Act(nn.GroupNorm): + def __init__( + self, + num_channels, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): + super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: @@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm): class LayerNormAct(nn.LayerNorm): def __init__( - self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + normalization_shape: Union[int, List[int], torch.Size], + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module @@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm): class LayerNormAct2d(nn.LayerNorm): def __init__( - self, num_channels, eps=1e-5, affine=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + num_channels, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 5ecc8915..ea945ccd 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -8,6 +8,7 @@ from .convmixer import * from .convnext import * from .crossvit import * from .cspnet import * +from .davit import * from .deit import * from .densenet import * from .dla import * diff --git a/timm/models/_builder.py b/timm/models/_builder.py index 901d7d44..32a35304 100644 --- a/timm/models/_builder.py +++ b/timm/models/_builder.py @@ -179,11 +179,11 @@ def load_pretrained( return if filter_fn is not None: - # for backwards compat with filter fn that take one arg, try one first, the two try: - state_dict = filter_fn(state_dict) - except TypeError: state_dict = filter_fn(state_dict, model) + except TypeError as e: + # for backwards compat with filter fn that take one arg + state_dict = filter_fn(state_dict) input_convs = pretrained_cfg.get('first_conv', None) if input_convs is not None and in_chans != 3: diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 7c64df0b..378d646c 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -209,6 +209,7 @@ def push_to_hf_hub( private: bool = False, create_pr: bool = False, model_config: Optional[dict] = None, + model_card: Optional[dict] = None, ): # Create repo if it doesn't exist yet repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True) @@ -232,9 +233,10 @@ def push_to_hf_hub( # Add readme if it does not exist if not has_readme: + model_card = model_card or {} model_name = repo_id.split('/')[-1] readme_path = Path(tmpdir) / "README.md" - readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}' + readme_text = generate_readme(model_card, model_name) readme_path.write_text(readme_text) # Upload model and return @@ -245,3 +247,51 @@ def push_to_hf_hub( create_pr=create_pr, commit_message=commit_message, ) + + +def generate_readme(model_card, model_name): + readme_text = "---\n" + readme_text += "tags:\n- image-classification\n- timm\n" + readme_text += "library_tag: timm\n" + readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n" + if 'details' in model_card and 'Dataset' in model_card['details']: + readme_text += 'datasets:\n' + readme_text += f"- {model_card['details']['Dataset'].lower()}\n" + if 'Pretrain Dataset' in model_card['details']: + readme_text += f"- {model_card['details']['Pretrain Dataset'].lower()}\n" + readme_text += "---\n" + readme_text += f"# Model card for {model_name}\n" + if 'description' in model_card: + readme_text += f"\n{model_card['description']}\n" + if 'details' in model_card: + readme_text += f"\n## Model Details\n" + for k, v in model_card['details'].items(): + if isinstance(v, (list, tuple)): + readme_text += f"- **{k}:**\n" + for vi in v: + readme_text += f" - {vi}\n" + elif isinstance(v, dict): + readme_text += f"- **{k}:**\n" + for ki, vi in v.items(): + readme_text += f" - {ki}: {vi}\n" + else: + readme_text += f"- **{k}:** {v}\n" + if 'usage' in model_card: + readme_text += f"\n## Model Usage\n" + readme_text += model_card['usage'] + readme_text += '\n' + + if 'comparison' in model_card: + readme_text += f"\n## Model Comparison\n" + readme_text += model_card['comparison'] + readme_text += '\n' + + if 'citation' in model_card: + readme_text += f"\n## Citation\n" + if not isinstance(model_card['citation'], (list, tuple)): + citations = [model_card['citation']] + else: + citations = model_card['citation'] + for c in citations: + readme_text += f"```bibtex\n{c}\n```\n" + return readme_text diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 15f78044..1c7f1137 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): def interleave_blocks( - types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs + types: Tuple[str, str], d, + every: Union[int, List[int]] = 1, + first: bool = False, + **kwargs, ) -> Tuple[ByoBlockCfg]: """ interleave 2 block types in stack """ @@ -1587,15 +1590,32 @@ class ByobNet(nn.Module): in_chans=3, global_pool='avg', output_stride=32, - zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0., + zero_init_last=True, + **kwargs, ): + """ + + Args: + cfg (ByoModelCfg): Model architecture configuration + num_classes (int): Number of classifier classes (default: 1000) + in_chans (int): Number of input channels (default: 3) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn) + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False + + cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg layers = get_layer_fns(cfg) if cfg.fixed_input_size: assert img_size is not None, 'img_size argument is required for fixed input size model' diff --git a/timm/models/convnext.py b/timm/models/convnext.py index e9214429..2bbe0b11 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -43,7 +43,7 @@ from functools import partial import torch import torch.nn as nn -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple from ._builder import build_model_with_cfg @@ -167,7 +167,7 @@ class ConvNeXtStage(nn.Module): conv_bias=conv_bias, use_grn=use_grn, act_layer=act_layer, - norm_layer=norm_layer if conv_mlp else norm_layer_cl + norm_layer=norm_layer if conv_mlp else norm_layer_cl, )) in_chs = out_chs self.blocks = nn.Sequential(*stage_blocks) @@ -184,16 +184,6 @@ class ConvNeXtStage(nn.Module): class ConvNeXt(nn.Module): r""" ConvNeXt A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf - - Args: - in_chans (int): Number of input image channels. Default: 3 - num_classes (int): Number of classes for classification head. Default: 1000 - depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] - dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768] - drop_rate (float): Head dropout rate - drop_path_rate (float): Stochastic depth rate. Default: 0. - ls_init_value (float): Init value for Layer Scale. Default: 1e-6. - head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. """ def __init__( @@ -215,19 +205,47 @@ class ConvNeXt(nn.Module): use_grn=False, act_layer='gelu', norm_layer=None, + norm_eps=None, drop_rate=0., drop_path_rate=0., ): + """ + Args: + in_chans (int): Number of input image channels (default: 3) + num_classes (int): Number of classes for classification head (default: 1000) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3]) + dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768]) + kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7) + ls_init_value (float): Init value for Layer Scale (default: 1e-6) + stem_type (str): Type of stem (default: 'patch') + patch_size (int): Stem patch size for patch stem (default: 4) + head_init_scale (float): Init scaling value for classifier weights and biases (default: 1) + head_norm_first (bool): Apply normalization before global pool + head (default: False) + conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False) + conv_bias (bool): Use bias layers w/ all convolutions (default: True) + use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False) + act_layer (Union[str, nn.Module]): Activation Layer + norm_layer (Union[str, nn.Module]): Normalization Layer + drop_rate (float): Head dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth rate (default: 0.) + """ super().__init__() assert output_stride in (8, 16, 32) kernel_sizes = to_ntuple(4)(kernel_sizes) if norm_layer is None: norm_layer = LayerNorm2d norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) self.num_classes = num_classes self.drop_rate = drop_rate @@ -238,7 +256,7 @@ class ConvNeXt(nn.Module): # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias), - norm_layer(dims[0]) + norm_layer(dims[0]), ) stem_stride = patch_size else: @@ -279,7 +297,7 @@ class ConvNeXt(nn.Module): use_grn=use_grn, act_layer=act_layer, norm_layer=norm_layer, - norm_layer_cl=norm_layer_cl + norm_layer_cl=norm_layer_cl, )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 @@ -289,11 +307,10 @@ class ConvNeXt(nn.Module): # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) - self.head_norm_first = head_norm_first self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() self.head = nn.Sequential(OrderedDict([ ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), ('drop', nn.Dropout(self.drop_rate)), ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) @@ -324,14 +341,7 @@ class ConvNeXt(nn.Module): if global_pool is not None: self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() - if num_classes == 0: - self.head.norm = nn.Identity() - self.head.fc = nn.Identity() - else: - if not self.head_norm_first: - norm_layer = type(self.stem[-1]) # obtain type from stem norm - self.head.norm = norm_layer(self.num_features) - self.head.fc = nn.Linear(self.num_features, num_classes) + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) @@ -372,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model): return state_dict # non-FB checkpoint if 'model' in state_dict: state_dict = state_dict['model'] + out_dict = {} + if 'visual.trunk.stem.0.weight' in state_dict: + out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')} + if 'visual.head.proj.weight' in state_dict: + out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight'] + out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0]) + return out_dict + import re for k, v in state_dict.items(): k = k.replace('downsample_layers.0.', 'stem.') @@ -391,10 +409,16 @@ def checkpoint_filter_fn(state_dict, model): model_shape = model.state_dict()[k].shape v = v.reshape(model_shape) out_dict[k] = v + return out_dict def _create_convnext(variant, pretrained=False, **kwargs): + if kwargs.get('pretrained_cfg', '') == 'fcmae': + # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`) + # This is workaround loading with num_classes=0 w/o removing norm-layer. + kwargs.setdefault('pretrained_strict', False) + model = build_model_with_cfg( ConvNeXt, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, @@ -469,10 +493,29 @@ default_cfgs = generate_default_cfgs({ url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', hf_hub_id='timm/', crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_tiny.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + 'convnext_small.in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0), + + 'convnext_tiny.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'convnext_small.in12k_ft_in1k_384': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'convnext_nano.in12k': _cfg( hf_hub_id='timm/', crop_pct=0.95, num_classes=11821), + 'convnext_tiny.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), + 'convnext_small.in12k': _cfg( + hf_hub_id='timm/', + crop_pct=0.95, num_classes=11821), 'convnext_tiny.fb_in1k': _cfg( url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", @@ -664,6 +707,33 @@ default_cfgs = generate_default_cfgs({ num_classes=0), 'convnextv2_small.untrained': _cfg(), + + # CLIP based weights, original image tower weights and fine-tunes + 'convnext_base.clip_laion2b': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laion2b_augreg': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_320': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), + 'convnext_base.clip_laiona_augreg_320': _cfg( + hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640), }) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 280f929e..da9d1ae0 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage Hacked together by / Copyright 2020 Ross Wightman """ -from dataclasses import dataclass, asdict +from dataclasses import dataclass, asdict, replace from functools import partial from typing import Any, Dict, Optional, Tuple, Union @@ -518,7 +518,7 @@ class CrossStage(nn.Module): cross_linear=False, block_dpr=None, block_fn=BottleneckBlock, - **block_kwargs + **block_kwargs, ): super(CrossStage, self).__init__() first_dilation = first_dilation or dilation @@ -558,7 +558,7 @@ class CrossStage(nn.Module): bottle_ratio=bottle_ratio, groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., - **block_kwargs + **block_kwargs, )) prev_chs = block_out_chs @@ -597,7 +597,7 @@ class CrossStage3(nn.Module): cross_linear=False, block_dpr=None, block_fn=BottleneckBlock, - **block_kwargs + **block_kwargs, ): super(CrossStage3, self).__init__() first_dilation = first_dilation or dilation @@ -635,7 +635,7 @@ class CrossStage3(nn.Module): bottle_ratio=bottle_ratio, groups=groups, drop_path=block_dpr[i] if block_dpr is not None else 0., - **block_kwargs + **block_kwargs, )) prev_chs = block_out_chs @@ -668,7 +668,7 @@ class DarkStage(nn.Module): avg_down=False, block_fn=BottleneckBlock, block_dpr=None, - **block_kwargs + **block_kwargs, ): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation @@ -715,7 +715,7 @@ def create_csp_stem( padding='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - aa_layer=None + aa_layer=None, ): stem = nn.Sequential() feature_info = [] @@ -738,7 +738,7 @@ def create_csp_stem( stride=conv_stride, padding=padding if i == 0 else '', act_layer=act_layer, - norm_layer=norm_layer + norm_layer=norm_layer, )) stem_stride *= conv_stride prev_chs = chs @@ -800,7 +800,7 @@ def create_csp_stages( cfg: CspModelCfg, drop_path_rate: float, output_stride: int, - stem_feat: Dict[str, Any] + stem_feat: Dict[str, Any], ): cfg_dict = asdict(cfg.stages) num_stages = len(cfg.stages.depth) @@ -868,12 +868,27 @@ class CspNet(nn.Module): global_pool='avg', drop_rate=0., drop_path_rate=0., - zero_init_last=True + zero_init_last=True, + **kwargs, ): + """ + Args: + cfg (CspModelCfg): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + global_pool (str): Global pooling type (default: 'avg') + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) + + cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg layer_args = dict( act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, @@ -898,7 +913,7 @@ class CspNet(nn.Module): # Construct the head self.num_features = prev_chs self.head = ClassifierHead( - in_chs=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) diff --git a/timm/models/davit.py b/timm/models/davit.py new file mode 100644 index 00000000..8b9e67b4 --- /dev/null +++ b/timm/models/davit.py @@ -0,0 +1,679 @@ +""" DaViT: Dual Attention Vision Transformers + +As described in https://arxiv.org/abs/2204.03645 + +Input size invariant transformer architecture that combines channel and spacial +attention in each block. The attention mechanisms used are linear in complexity. + +DaViT model defs and weights adapted from https://github.com/dingmyu/davit, original copyright below + +""" +# Copyright (c) 2022 Mingyu Ding +# All rights reserved. +# This source code is licensed under the MIT license +import itertools +from collections import OrderedDict +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer +from ._builder import build_model_with_cfg +from ._features_fx import register_notrace_function +from ._manipulate import checkpoint_seq +from ._pretrained import generate_default_cfgs +from ._registry import register_model + +__all__ = ['DaViT'] + + +class ConvPosEnc(nn.Module): + def __init__(self, dim: int, k: int = 3, act: bool = False): + super(ConvPosEnc, self).__init__() + + self.proj = nn.Conv2d(dim, dim, k, 1, k // 2, groups=dim) + self.act = nn.GELU() if act else nn.Identity() + + def forward(self, x: Tensor): + feat = self.proj(x) + x = x + self.act(feat) + return x + + +class Stem(nn.Module): + """ Size-agnostic implementation of 2D image to patch embedding, + allowing input size to be adjusted during model forward operation + """ + + def __init__( + self, + in_chs=3, + out_chs=96, + stride=4, + norm_layer=LayerNorm2d, + ): + super().__init__() + stride = to_2tuple(stride) + self.stride = stride + self.in_chs = in_chs + self.out_chs = out_chs + assert stride[0] == 4 # only setup for stride==4 + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=7, + stride=stride, + padding=3, + ) + self.norm = norm_layer(out_chs) + + def forward(self, x: Tensor): + B, C, H, W = x.shape + x = F.pad(x, (0, (self.stride[1] - W % self.stride[1]) % self.stride[1])) + x = F.pad(x, (0, 0, 0, (self.stride[0] - H % self.stride[0]) % self.stride[0])) + x = self.conv(x) + x = self.norm(x) + return x + + +class Downsample(nn.Module): + def __init__( + self, + in_chs, + out_chs, + norm_layer=LayerNorm2d, + ): + super().__init__() + self.in_chs = in_chs + self.out_chs = out_chs + + self.norm = norm_layer(in_chs) + self.conv = nn.Conv2d( + in_chs, + out_chs, + kernel_size=2, + stride=2, + padding=0, + ) + + def forward(self, x: Tensor): + B, C, H, W = x.shape + x = self.norm(x) + x = F.pad(x, (0, (2 - W % 2) % 2)) + x = F.pad(x, (0, 0, 0, (2 - H % 2) % 2)) + x = self.conv(x) + return x + + +class ChannelAttention(nn.Module): + + def __init__(self, dim, num_heads=8, qkv_bias=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + def forward(self, x: Tensor): + B, N, C = x.shape + + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + k = k * self.scale + attention = k.transpose(-1, -2) @ v + attention = attention.softmax(dim=-1) + x = (attention @ q.transpose(-1, -2)).transpose(-1, -2) + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + return x + + +class ChannelBlock(nn.Module): + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4., + qkv_bias=False, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ffn=True, + cpe_act=False, + ): + super().__init__() + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.ffn = ffn + self.norm1 = norm_layer(dim) + self.attn = ChannelAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + + if self.ffn: + self.norm2 = norm_layer(dim) + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + else: + self.norm2 = None + self.mlp = None + self.drop_path2 = None + + def forward(self, x: Tensor): + B, C, H, W = x.shape + + x = self.cpe1(x).flatten(2).transpose(1, 2) + + cur = self.norm1(x) + cur = self.attn(cur) + x = x + self.drop_path1(cur) + + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)) + + if self.mlp is not None: + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).view(B, C, H, W) + + return x + + +def window_partition(x: Tensor, window_size: Tuple[int, int]): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows: Tensor, window_size: Tuple[int, int], H: int, W: int): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True): + super().__init__() + self.dim = dim + self.window_size = window_size + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x: Tensor): + B_, N, C = x.shape + + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + attn = self.softmax(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + return x + + +class SpatialBlock(nn.Module): + r""" Windows Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + mlp_ratio=4., + qkv_bias=True, + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ffn=True, + cpe_act=False, + ): + super().__init__() + self.dim = dim + self.ffn = ffn + self.num_heads = num_heads + self.window_size = to_2tuple(window_size) + self.mlp_ratio = mlp_ratio + + self.cpe1 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.cpe2 = ConvPosEnc(dim=dim, k=3, act=cpe_act) + if self.ffn: + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + else: + self.norm2 = None + self.mlp = None + self.drop_path1 = None + + def forward(self, x: Tensor): + B, C, H, W = x.shape + + shortcut = self.cpe1(x).flatten(2).transpose(1, 2) + + x = self.norm1(shortcut) + x = x.view(B, H, W, C) + + pad_l = pad_t = 0 + pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x_windows = window_partition(x, self.window_size) + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1], C) + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) + + # if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + x = shortcut + self.drop_path1(x) + + x = self.cpe2(x.transpose(1, 2).view(B, C, H, W)) + + if self.mlp is not None: + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path2(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).view(B, C, H, W) + + return x + + +class DaViTStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + depth=1, + downsample=True, + attn_types=('spatial', 'channel'), + num_heads=3, + window_size=7, + mlp_ratio=4, + qkv_bias=True, + drop_path_rates=(0, 0), + norm_layer=LayerNorm2d, + norm_layer_cl=nn.LayerNorm, + ffn=True, + cpe_act=False + ): + super().__init__() + + self.grad_checkpointing = False + + # downsample embedding layer at the beginning of each stage + if downsample: + self.downsample = Downsample(in_chs, out_chs, norm_layer=norm_layer) + else: + self.downsample = nn.Identity() + + ''' + repeating alternating attention blocks in each stage + default: (spatial -> channel) x depth + + potential opportunity to integrate with a more general version of ByobNet/ByoaNet + since the logic is similar + ''' + stage_blocks = [] + for block_idx in range(depth): + dual_attention_block = [] + for attn_idx, attn_type in enumerate(attn_types): + if attn_type == 'spatial': + dual_attention_block.append(SpatialBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act, + window_size=window_size, + )) + elif attn_type == 'channel': + dual_attention_block.append(ChannelBlock( + dim=out_chs, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=drop_path_rates[block_idx], + norm_layer=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act + )) + stage_blocks.append(nn.Sequential(*dual_attention_block)) + self.blocks = nn.Sequential(*stage_blocks) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + def forward(self, x: Tensor): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class DaViT(nn.Module): + r""" DaViT + A PyTorch implementation of `DaViT: Dual Attention Vision Transformers` - https://arxiv.org/abs/2204.03645 + Supports arbitrary input sizes and pyramid feature extraction + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks in each stage. Default: (1, 1, 3, 1) + embed_dims (tuple(int)): Patch embedding dimension. Default: (96, 192, 384, 768) + num_heads (tuple(int)): Number of attention heads in different layers. Default: (3, 6, 12, 24) + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__( + self, + in_chans=3, + depths=(1, 1, 3, 1), + embed_dims=(96, 192, 384, 768), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4, + qkv_bias=True, + norm_layer='layernorm2d', + norm_layer_cl='layernorm', + norm_eps=1e-5, + attn_types=('spatial', 'channel'), + ffn=True, + cpe_act=False, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0., + num_classes=1000, + global_pool='avg', + head_norm_first=False, + ): + super().__init__() + num_stages = len(embed_dims) + assert num_stages == len(num_heads) == len(depths) + norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps) + norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps) + self.num_classes = num_classes + self.num_features = embed_dims[-1] + self.drop_rate = drop_rate + self.grad_checkpointing = False + self.feature_info = [] + + self.stem = Stem(in_chans, embed_dims[0], norm_layer=norm_layer) + in_chs = embed_dims[0] + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + stages = [] + for stage_idx in range(num_stages): + out_chs = embed_dims[stage_idx] + stage = DaViTStage( + in_chs, + out_chs, + depth=depths[stage_idx], + downsample=stage_idx > 0, + attn_types=attn_types, + num_heads=num_heads[stage_idx], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path_rates=dpr[stage_idx], + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ffn=ffn, + cpe_act=cpe_act, + ) + in_chs = out_chs + stages.append(stage) + self.feature_info += [dict(num_chs=out_chs, reduction=2, module=f'stages.{stage_idx}')] + + self.stages = nn.Sequential(*stages) + + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets + # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt + # FIXME generalize this structure to ClassifierHead + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + for stage in self.stages: + stage.set_grad_checkpointing(enable=enable) + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + """ Remap MSFT checkpoints -> timm """ + if 'head.fc.weight' in state_dict: + return state_dict # non-MSFT checkpoint + + if 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + import re + out_dict = {} + for k, v in state_dict.items(): + k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.downsample', k) + k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) + k = k.replace('downsample.proj', 'downsample.conv') + k = k.replace('stages.0.downsample', 'stem') + k = k.replace('head.', 'head.fc.') + k = k.replace('norms.', 'head.norm.') + k = k.replace('cpe.0', 'cpe1') + k = k.replace('cpe.1', 'cpe2') + out_dict[k] = v + return out_dict + + +def _create_davit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + + model = build_model_with_cfg( + DaViT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.95, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + + +# TODO contact authors to get larger pretrained models +default_cfgs = generate_default_cfgs({ + # official microsoft weights from https://github.com/dingmyu/davit + 'davit_tiny.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_small.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_base.msft_in1k': _cfg( + hf_hub_id='timm/'), + 'davit_large': _cfg(), + 'davit_huge': _cfg(), + 'davit_giant': _cfg(), +}) + + +@register_model +def davit_tiny(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 3, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_tiny', pretrained=pretrained, **model_kwargs) + + +@register_model +def davit_small(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(96, 192, 384, 768), num_heads=(3, 6, 12, 24), **kwargs) + return _create_davit('davit_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def davit_base(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(128, 256, 512, 1024), num_heads=(4, 8, 16, 32), **kwargs) + return _create_davit('davit_base', pretrained=pretrained, **model_kwargs) + + +@register_model +def davit_large(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(192, 384, 768, 1536), num_heads=(6, 12, 24, 48), **kwargs) + return _create_davit('davit_large', pretrained=pretrained, **model_kwargs) + + +@register_model +def davit_huge(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 9, 1), embed_dims=(256, 512, 1024, 2048), num_heads=(8, 16, 32, 64), **kwargs) + return _create_davit('davit_huge', pretrained=pretrained, **model_kwargs) + + +@register_model +def davit_giant(pretrained=False, **kwargs): + model_kwargs = dict( + depths=(1, 1, 12, 3), embed_dims=(384, 768, 1536, 3072), num_heads=(12, 24, 48, 96), **kwargs) + return _create_davit('davit_giant', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index e731f7b0..ccbb491c 100644 --- a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -12,7 +12,7 @@ import torch.utils.checkpoint as cp from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier +from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier from ._builder import build_model_with_cfg from ._manipulate import MATCH_PREV_GROUP from ._registry import register_model @@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict): _version = 2 def __init__( - self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d, - drop_rate=0., memory_efficient=False): + self, + num_layers, + num_input_features, + bn_size, + growth_rate, + norm_layer=BatchNormAct2d, + drop_rate=0., + memory_efficient=False, + ): super(DenseBlock, self).__init__() for i in range(num_layers): layer = DenseLayer( @@ -165,12 +172,25 @@ class DenseNet(nn.Module): """ def __init__( - self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg', - bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0, - memory_efficient=False, aa_stem_only=True): + self, + growth_rate=32, + block_config=(6, 12, 24, 16), + num_classes=1000, + in_chans=3, + global_pool='avg', + bn_size=4, + stem_type='', + act_layer='relu', + norm_layer='batchnorm2d', + aa_layer=None, + drop_rate=0, + memory_efficient=False, + aa_stem_only=True, + ): self.num_classes = num_classes self.drop_rate = drop_rate super(DenseNet, self).__init__() + norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) # Stem deep_stem = 'deep' in stem_type # 3x3 deep stem @@ -226,8 +246,11 @@ class DenseNet(nn.Module): dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)] current_stride *= 2 trans = DenseTransition( - num_input_features=num_features, num_output_features=num_features // 2, - norm_layer=norm_layer, aa_layer=transition_aa_layer) + num_input_features=num_features, + num_output_features=num_features // 2, + norm_layer=norm_layer, + aa_layer=transition_aa_layer, + ) self.features.add_module(f'transition{i + 1}', trans) num_features = num_features // 2 @@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs): `"Densely Connected Convolutional Networks" ` """ model = _create_densenet( - 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep', - aa_layer=BlurPool2d, **kwargs) + 'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, + stem_type='deep', aa_layer=BlurPool2d, **kwargs) return model @@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs): def densenet264d_iabn(pretrained=False, **kwargs): r"""Densenet-264 model with deep stem and Inplace-ABN """ - def norm_act_fn(num_features, **kwargs): - return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs) model = _create_densenet( 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', - norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) + norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs) return model diff --git a/timm/models/dpn.py b/timm/models/dpn.py index 87bd918f..29a7a7e8 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -15,7 +15,7 @@ import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier +from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer from ._builder import build_model_with_cfg from ._registry import register_model @@ -33,6 +33,7 @@ def _cfg(url='', **kwargs): default_cfgs = { + 'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'dpn68': _cfg( url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'), 'dpn68b': _cfg( @@ -82,7 +83,16 @@ class BnActConv2d(nn.Module): class DualPathBlock(nn.Module): def __init__( - self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False): + self, + in_chs, + num_1x1_a, + num_3x3_b, + num_1x1_c, + inc, + groups, + block_type='normal', + b=False, + ): super(DualPathBlock, self).__init__() self.num_1x1_c = num_1x1_c self.inc = inc @@ -167,16 +177,31 @@ class DualPathBlock(nn.Module): class DPN(nn.Module): def __init__( - self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg', - b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32, - num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU): + self, + k_sec=(3, 4, 20, 3), + inc_sec=(16, 32, 24, 128), + k_r=96, + groups=32, + num_classes=1000, + in_chans=3, + output_stride=32, + global_pool='avg', + small=False, + num_init_features=64, + b=False, + drop_rate=0., + norm_layer='batchnorm2d', + act_layer='relu', + fc_act_layer='elu', + ): super(DPN, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.b = b assert output_stride == 32 # FIXME look into dilation support - norm_layer = partial(BatchNormAct2d, eps=.001) - fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False) + + norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001) + fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False) bw_factor = 1 if small else 4 blocks = OrderedDict() @@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs): **kwargs) +@register_model +def dpn48b(pretrained=False, **kwargs): + model_kwargs = dict( + small=True, num_init_features=10, k_r=128, groups=32, + b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu') + return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + + @register_model def dpn68(pretrained=False, **kwargs): model_kwargs = dict( small=True, num_init_features=10, k_r=128, groups=32, - k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) - return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs) + k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64)) + return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model def dpn68b(pretrained=False, **kwargs): model_kwargs = dict( small=True, num_init_features=10, k_r=128, groups=32, - b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs) - return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs) + b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64)) + return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model def dpn92(pretrained=False, **kwargs): model_kwargs = dict( num_init_features=64, k_r=96, groups=32, - k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs) - return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs) + k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128)) + return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model def dpn98(pretrained=False, **kwargs): model_kwargs = dict( num_init_features=96, k_r=160, groups=40, - k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs) - return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs) + k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128)) + return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model def dpn131(pretrained=False, **kwargs): model_kwargs = dict( num_init_features=128, k_r=160, groups=40, - k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs) - return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs) + k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128)) + return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model def dpn107(pretrained=False, **kwargs): model_kwargs = dict( num_init_features=128, k_r=200, groups=50, - k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs) - return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs) + k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128)) + return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs)) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index 1170e7e3..e730fa30 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -12,9 +12,6 @@ These configs work well and appear to be a bit faster / lower resource than the The models without extra prefix / suffix' (coatnet_0_224, maxvit_tiny_224, etc), are intended to match paper, BUT, without any official pretrained weights it's difficult to confirm a 100% match. -# FIXME / WARNING -This impl remains a WIP, some configs and models may vanish or change... - Papers: MaxViT: Multi-Axis Vision Transformer - https://arxiv.org/abs/2204.01697 @@ -76,6 +73,8 @@ class MaxxVitTransformerCfg: partition_ratio: int = 32 window_size: Optional[Tuple[int, int]] = None grid_size: Optional[Tuple[int, int]] = None + no_block_attn: bool = False # disable window block attention for maxvit (ie only grid) + use_nchw_attn: bool = False # for MaxViT variants (not used for CoAt), keep tensors in NCHW order init_values: Optional[float] = None act_layer: str = 'gelu' norm_layer: str = 'layernorm2d' @@ -889,19 +888,17 @@ class MaxxVitBlock(nn.Module): stride: int = 1, conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg(), transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg(), - use_nchw_attn: bool = False, # FIXME move to cfg? True is ~20-30% faster on TPU, 5-10% slower on GPU - use_block_attn: bool = True, # FIXME for testing ConvNeXt conv w/o block attention drop_path: float = 0., ): super().__init__() + self.nchw_attn = transformer_cfg.use_nchw_attn conv_cls = ConvNeXtBlock if conv_cfg.block_type == 'convnext' else MbConvBlock self.conv = conv_cls(dim, dim_out, stride=stride, cfg=conv_cfg, drop_path=drop_path) attn_kwargs = dict(dim=dim_out, cfg=transformer_cfg, drop_path=drop_path) - partition_layer = PartitionAttention2d if use_nchw_attn else PartitionAttentionCl - self.nchw_attn = use_nchw_attn - self.attn_block = partition_layer(**attn_kwargs) if use_block_attn else None + partition_layer = PartitionAttention2d if self.nchw_attn else PartitionAttentionCl + self.attn_block = None if transformer_cfg.no_block_attn else partition_layer(**attn_kwargs) self.attn_grid = partition_layer(partition_type='grid', **attn_kwargs) def init_weights(self, scheme=''): @@ -1084,26 +1081,48 @@ class NormMlpHead(nn.Module): hidden_size=None, pool_type='avg', drop_rate=0., - norm_layer=nn.LayerNorm, - act_layer=nn.Tanh, + norm_layer='layernorm2d', + act_layer='tanh', ): super().__init__() self.drop_rate = drop_rate + self.in_features = in_features + self.hidden_size = hidden_size self.num_features = in_features + self.use_conv = not pool_type + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) self.norm = norm_layer(in_features) self.flatten = nn.Flatten(1) if pool_type else nn.Identity() if hidden_size: self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(in_features, hidden_size)), + ('fc', linear_layer(in_features, hidden_size)), ('act', act_layer()), ])) self.num_features = hidden_size else: self.pre_logits = nn.Identity() self.drop = nn.Dropout(self.drop_rate) - self.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.use_conv = self.global_pool.is_identity() + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear + if self.hidden_size: + if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or + (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)): + with torch.no_grad(): + new_fc = linear_layer(self.in_features, self.hidden_size) + new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape)) + new_fc.bias.copy_(self.pre_logits.fc.bias) + self.pre_logits.fc = new_fc + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() def forward(self, x, pre_logits: bool = False): x = self.global_pool(x) @@ -1116,6 +1135,26 @@ class NormMlpHead(nn.Module): return x +def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs): + transformer_kwargs = {} + conv_kwargs = {} + base_kwargs = {} + for k, v in kwargs.items(): + if k.startswith('transformer_'): + transformer_kwargs[k.replace('transformer_', '')] = v + elif k.startswith('conv_'): + conv_kwargs[k.replace('conv_', '')] = v + else: + base_kwargs[k] = v + cfg = replace( + cfg, + transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs), + conv_cfg=replace(cfg.conv_cfg, **conv_kwargs), + **base_kwargs + ) + return cfg + + class MaxxVit(nn.Module): """ CoaTNet + MaxVit base model. @@ -1130,16 +1169,20 @@ class MaxxVit(nn.Module): num_classes: int = 1000, global_pool: str = 'avg', drop_rate: float = 0., - drop_path_rate: float = 0. + drop_path_rate: float = 0., + **kwargs, ): super().__init__() img_size = to_2tuple(img_size) + if kwargs: + cfg = _overlay_kwargs(cfg, **kwargs) transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size) self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = cfg.embed_dim[-1] self.drop_rate = drop_rate self.grad_checkpointing = False + self.feature_info = [] self.stem = Stem( in_chs=in_chans, @@ -1150,8 +1193,8 @@ class MaxxVit(nn.Module): norm_layer=cfg.conv_cfg.norm_layer, norm_eps=cfg.conv_cfg.norm_eps, ) - stride = self.stem.stride + self.feature_info += [dict(num_chs=self.stem.out_chs, reduction=2, module='stem')] feat_size = tuple([i // s for i, s in zip(img_size, to_2tuple(stride))]) num_stages = len(cfg.embed_dim) @@ -1175,15 +1218,17 @@ class MaxxVit(nn.Module): )] stride *= stage_stride in_chs = out_chs + self.feature_info += [dict(num_chs=out_chs, reduction=stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) final_norm_layer = partial(get_norm_layer(cfg.transformer_cfg.norm_layer), eps=cfg.transformer_cfg.norm_eps) - if cfg.head_hidden_size: + self.head_hidden_size = cfg.head_hidden_size + if self.head_hidden_size: self.norm = nn.Identity() self.head = NormMlpHead( self.num_features, num_classes, - hidden_size=cfg.head_hidden_size, + hidden_size=self.head_hidden_size, pool_type=global_pool, drop_rate=drop_rate, norm_layer=final_norm_layer, @@ -1230,9 +1275,7 @@ class MaxxVit(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if global_pool is None: - global_pool = self.head.global_pool.pool_type - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) + self.head.reset(num_classes, global_pool) def forward_features(self, x): x = self.stem(x) @@ -1353,6 +1396,7 @@ def _next_cfg( transformer_norm_layer='layernorm2d', transformer_norm_layer_cl='layernorm', window_size=None, + no_block_attn=False, init_values=1e-6, rel_pos_type='mlp', # MLP by default for maxxvit rel_pos_dim=512, @@ -1373,6 +1417,7 @@ def _next_cfg( expand_first=False, pool_type=pool_type, window_size=window_size, + no_block_attn=no_block_attn, # enabled for MaxxViT-V2 init_values=init_values[1], norm_layer=transformer_norm_layer, norm_layer_cl=transformer_norm_layer_cl, @@ -1399,8 +1444,8 @@ def _tf_cfg(): model_cfgs = dict( - # Fiddling with configs / defaults / still pretraining - coatnet_pico_rw_224=MaxxVitCfg( + # timm specific CoAtNet configs + coatnet_pico_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 3, 5, 2), stem_width=(32, 64), @@ -1409,7 +1454,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_nano_rw_224=MaxxVitCfg( + coatnet_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1419,7 +1464,7 @@ model_cfgs = dict( conv_attn_ratio=0.25, ), ), - coatnet_0_rw_224=MaxxVitCfg( + coatnet_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1428,7 +1473,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ), ), - coatnet_1_rw_224=MaxxVitCfg( + coatnet_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1438,7 +1483,7 @@ model_cfgs = dict( transformer_shortcut_bias=False, ) ), - coatnet_2_rw_224=MaxxVitCfg( + coatnet_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1448,7 +1493,7 @@ model_cfgs = dict( #init_values=1e-6, ), ), - coatnet_3_rw_224=MaxxVitCfg( + coatnet_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1459,8 +1504,8 @@ model_cfgs = dict( ), ), - # Highly experimental configs - coatnet_bn_0_rw_224=MaxxVitCfg( + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + coatnet_bn_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1471,7 +1516,7 @@ model_cfgs = dict( transformer_norm_layer='batchnorm2d', ) ), - coatnet_rmlp_nano_rw_224=MaxxVitCfg( + coatnet_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1482,7 +1527,7 @@ model_cfgs = dict( rel_pos_dim=384, ), ), - coatnet_rmlp_0_rw_224=MaxxVitCfg( + coatnet_rmlp_0_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 7, 2), # deeper than paper '0' model stem_width=(32, 64), @@ -1491,7 +1536,7 @@ model_cfgs = dict( rel_pos_type='mlp', ), ), - coatnet_rmlp_1_rw_224=MaxxVitCfg( + coatnet_rmlp_1_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1503,7 +1548,7 @@ model_cfgs = dict( rel_pos_dim=384, # was supposed to be 512, woops ), ), - coatnet_rmlp_1_rw2_224=MaxxVitCfg( + coatnet_rmlp_1_rw2=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=(32, 64), @@ -1513,7 +1558,7 @@ model_cfgs = dict( rel_pos_dim=512, # was supposed to be 512, woops ), ), - coatnet_rmlp_2_rw_224=MaxxVitCfg( + coatnet_rmlp_2_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=(64, 128), @@ -1524,7 +1569,7 @@ model_cfgs = dict( rel_pos_type='mlp' ), ), - coatnet_rmlp_3_rw_224=MaxxVitCfg( + coatnet_rmlp_3_rw=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=(96, 192), @@ -1536,14 +1581,14 @@ model_cfgs = dict( ), ), - coatnet_nano_cc_224=MaxxVitCfg( + coatnet_nano_cc=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), block_type=('C', 'C', ('C', 'T'), ('C', 'T')), **_rw_coat_cfg(), ), - coatnext_nano_rw_224=MaxxVitCfg( + coatnext_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(3, 4, 6, 3), stem_width=(32, 64), @@ -1555,89 +1600,95 @@ model_cfgs = dict( ), # Trying to be like the CoAtNet paper configs - coatnet_0_224=MaxxVitCfg( + coatnet_0=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 3, 5, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_1_224=MaxxVitCfg( + coatnet_1=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 6, 14, 2), stem_width=64, + head_hidden_size=768, ), - coatnet_2_224=MaxxVitCfg( + coatnet_2=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 14, 2), stem_width=128, + head_hidden_size=1024, ), - coatnet_3_224=MaxxVitCfg( + coatnet_3=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 6, 14, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_4_224=MaxxVitCfg( + coatnet_4=MaxxVitCfg( embed_dim=(192, 384, 768, 1536), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=1536, ), - coatnet_5_224=MaxxVitCfg( + coatnet_5=MaxxVitCfg( embed_dim=(256, 512, 1280, 2048), depths=(2, 12, 28, 2), stem_width=192, + head_hidden_size=2048, ), # Experimental MaxVit configs - maxvit_pico_rw_256=MaxxVitCfg( + maxvit_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(), ), - maxvit_nano_rw_256=MaxxVitCfg( + maxvit_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_224=MaxxVitCfg( + maxvit_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_tiny_rw_256=MaxxVitCfg( + maxvit_tiny_pm=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), - block_type=('M',) * 4, + block_type=('PM',) * 4, stem_width=(32, 64), **_rw_max_cfg(), ), - maxvit_rmlp_pico_rw_256=MaxxVitCfg( + maxvit_rmlp_pico_rw=MaxxVitCfg( embed_dim=(32, 64, 128, 256), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(24, 32), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_rw_max_cfg(rel_pos_type='mlp'), ), - maxvit_rmlp_small_rw_224=MaxxVitCfg( + maxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, @@ -1647,26 +1698,18 @@ model_cfgs = dict( init_values=1e-6, ), ), - maxvit_rmlp_small_rw_256=MaxxVitCfg( + maxvit_rmlp_base_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), - depths=(2, 2, 5, 2), + depths=(2, 6, 14, 2), block_type=('M',) * 4, stem_width=(32, 64), + head_hidden_size=768, **_rw_max_cfg( rel_pos_type='mlp', - init_values=1e-6, ), ), - maxvit_tiny_pm_256=MaxxVitCfg( - embed_dim=(64, 128, 256, 512), - depths=(2, 2, 5, 2), - block_type=('PM',) * 4, - stem_width=(32, 64), - **_rw_max_cfg(), - ), - - maxxvit_rmlp_nano_rw_256=MaxxVitCfg( + maxxvit_rmlp_nano_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(1, 2, 3, 1), block_type=('M',) * 4, @@ -1674,33 +1717,50 @@ model_cfgs = dict( weight_init='normal', **_next_cfg(), ), - maxxvit_rmlp_tiny_rw_256=MaxxVitCfg( + maxxvit_rmlp_tiny_rw=MaxxVitCfg( embed_dim=(64, 128, 256, 512), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(32, 64), **_next_cfg(), ), - maxxvit_rmlp_small_rw_256=MaxxVitCfg( + maxxvit_rmlp_small_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), depths=(2, 2, 5, 2), block_type=('M',) * 4, stem_width=(48, 96), **_next_cfg(), ), - maxxvit_rmlp_base_rw_224=MaxxVitCfg( + + maxxvitv2_nano_rw=MaxxVitCfg( embed_dim=(96, 192, 384, 768), - depths=(2, 6, 14, 2), + depths=(1, 2, 3, 1), block_type=('M',) * 4, stem_width=(48, 96), - **_next_cfg(), + weight_init='normal', + **_next_cfg( + no_block_attn=True, + rel_pos_type='bias', + ), ), - maxxvit_rmlp_large_rw_224=MaxxVitCfg( + maxxvitv2_rmlp_base_rw=MaxxVitCfg( embed_dim=(128, 256, 512, 1024), depths=(2, 6, 12, 2), block_type=('M',) * 4, stem_width=(64, 128), - **_next_cfg(), + **_next_cfg( + no_block_attn=True, + ), + ), + maxxvitv2_rmlp_large_rw=MaxxVitCfg( + embed_dim=(160, 320, 640, 1280), + depths=(2, 6, 16, 2), + block_type=('M',) * 4, + stem_width=(80, 160), + head_hidden_size=1280, + **_next_cfg( + no_block_attn=True, + ), ), # Trying to be like the MaxViT paper configs @@ -1752,11 +1812,29 @@ model_cfgs = dict( ) +def checkpoint_filter_fn(state_dict, model: nn.Module): + model_state_dict = model.state_dict() + out_dict = {} + for k, v in state_dict.items(): + if k in model_state_dict and v.ndim != model_state_dict[k].ndim and v.numel() == model_state_dict[k].numel(): + # adapt between conv2d / linear layers + assert v.ndim in (2, 4) + v = v.reshape(model_state_dict[k].shape) + out_dict[k] = v + return out_dict + + def _create_maxxvit(variant, cfg_variant=None, pretrained=False, **kwargs): + if cfg_variant is None: + if variant in model_cfgs: + cfg_variant = variant + else: + cfg_variant = '_'.join(variant.split('_')[:-1]) return build_model_with_cfg( MaxxVit, variant, pretrained, - model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], + model_cfg=model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), + pretrained_filter_fn=checkpoint_filter_fn, **kwargs) @@ -1772,149 +1850,218 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ - # Fiddling with configs / defaults / still pretraining - 'coatnet_pico_rw_224': _cfg(url=''), - 'coatnet_nano_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-1k pretrain, fixed rel-pos + 'coatnet_pico_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_nano_rw_224_sw-f53093b4.pth', crop_pct=0.9), - 'coatnet_0_rw_224': _cfg( + 'coatnet_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_0_rw_224_sw-a6439706.pth'), - 'coatnet_1_rw_224': _cfg( + 'coatnet_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_1_rw_224_sw-5cae1ea8.pth' ), - 'coatnet_2_rw_224': _cfg(url=''), - 'coatnet_3_rw_224': _cfg(url=''), - # Highly experimental configs - 'coatnet_bn_0_rw_224': _cfg( + # timm specific CoAtNet configs, ImageNet-12k pretrain w/ 1k fine-tune, fixed rel-pos + 'coatnet_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + #'coatnet_3_rw_224.untrained': _cfg(url=''), + + # Experimental CoAtNet configs w/ ImageNet-12k pretrain -> 1k fine-tune (different norm layers, MLP rel-pos) + 'coatnet_rmlp_1_rw2_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'coatnet_rmlp_2_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # Experimental CoAtNet configs w/ ImageNet-1k train (different norm layers, MLP rel-pos) + 'coatnet_bn_0_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_bn_0_rw_224_sw-c228e218.pth', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, crop_pct=0.95), - 'coatnet_rmlp_nano_rw_224': _cfg( + 'coatnet_rmlp_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_nano_rw_224_sw-bd1d51b3.pth', crop_pct=0.9), - 'coatnet_rmlp_0_rw_224': _cfg(url=''), - 'coatnet_rmlp_1_rw_224': _cfg( + 'coatnet_rmlp_0_rw_224.untrained': _cfg(url=''), + 'coatnet_rmlp_1_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_1_rw_224_sw-9051e6c3.pth'), - 'coatnet_rmlp_1_rw2_224': _cfg(url=''), - 'coatnet_rmlp_2_rw_224': _cfg( + 'coatnet_rmlp_2_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnet_rmlp_2_rw_224_sw-5ccfac55.pth'), - 'coatnet_rmlp_3_rw_224': _cfg(url=''), - 'coatnet_nano_cc_224': _cfg(url=''), - 'coatnext_nano_rw_224': _cfg( + 'coatnet_rmlp_3_rw_224.untrained': _cfg(url=''), + 'coatnet_nano_cc_224.untrained': _cfg(url=''), + 'coatnext_nano_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/coatnext_nano_rw_224_ad-22cb71c2.pth', crop_pct=0.9), - # Trying to be like the CoAtNet paper configs - 'coatnet_0_224': _cfg(url=''), - 'coatnet_1_224': _cfg(url=''), - 'coatnet_2_224': _cfg(url=''), - 'coatnet_3_224': _cfg(url=''), - 'coatnet_4_224': _cfg(url=''), - 'coatnet_5_224': _cfg(url=''), - - # Experimental configs - 'maxvit_pico_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_nano_rw_256': _cfg( + # ImagenNet-12k pretrain CoAtNet + 'coatnet_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_3_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_1_rw2_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + 'coatnet_rmlp_2_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), + + # Trying to be like the CoAtNet paper configs (will adapt if 'tf' weights are ever released) + 'coatnet_0_224.untrained': _cfg(url=''), + 'coatnet_1_224.untrained': _cfg(url=''), + 'coatnet_2_224.untrained': _cfg(url=''), + 'coatnet_3_224.untrained': _cfg(url=''), + 'coatnet_4_224.untrained': _cfg(url=''), + 'coatnet_5_224.untrained': _cfg(url=''), + + # timm specific MaxVit configs, ImageNet-1k pretrain or untrained + 'maxvit_pico_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxvit_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_nano_rw_256_sw-fb127241.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_rw_224': _cfg( + 'maxvit_tiny_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_tiny_rw_224_sw-7d0dffeb.pth'), - 'maxvit_tiny_rw_256': _cfg( + 'maxvit_tiny_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_pico_rw_256': _cfg( + 'maxvit_tiny_pm_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + + # timm specific MaxVit w/ MLP rel-pos, ImageNet-1k pretrain + 'maxvit_rmlp_pico_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_pico_rw_256_sw-8d82f2c6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_nano_rw_256': _cfg( + 'maxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_nano_rw_256_sw-c17bb0d6.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_tiny_rw_256': _cfg( + 'maxvit_rmlp_tiny_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_tiny_rw_256_sw-bbef0ff5.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_rmlp_small_rw_224': _cfg( + 'maxvit_rmlp_small_rw_224.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxvit_rmlp_small_rw_224_sw-6ef0ae4f.pth', crop_pct=0.9, ), - 'maxvit_rmlp_small_rw_256': _cfg( + 'maxvit_rmlp_small_rw_256.untrained': _cfg( url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + # timm specific MaxVit w/ ImageNet-12k pretrain and 1k fine-tune + 'maxvit_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + ), + 'maxvit_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + + # timm specific MaxVit w/ ImageNet-12k pretrain + 'maxvit_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821, + ), - 'maxxvit_rmlp_nano_rw_256': _cfg( + # timm MaxxViT configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks) + 'maxxvit_rmlp_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_nano_rw_256_sw-0325d459.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_tiny_rw_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_small_rw_256': _cfg( + 'maxxvit_rmlp_tiny_rw_256.untrained': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvit_rmlp_small_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-maxx/maxxvit_rmlp_small_rw_256_sw-37e217ff.pth', input_size=(3, 256, 256), pool_size=(8, 8)), - 'maxxvit_rmlp_base_rw_224': _cfg(url=''), - 'maxxvit_rmlp_large_rw_224': _cfg(url=''), + # timm MaxxViT-V2 configs (ConvNeXt conv blocks mixed with MaxVit transformer blocks, more width, no block attn) + 'maxxvitv2_nano_rw_256.sw_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 256, 256), pool_size=(8, 8)), + 'maxxvitv2_rmlp_base_rw_224.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'maxxvitv2_rmlp_base_rw_384.sw_in12k_ft_in1k': _cfg( + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), + 'maxxvitv2_rmlp_large_rw_224.untrained': _cfg(url=''), + + 'maxxvitv2_rmlp_base_rw_224.sw_in12k': _cfg( + hf_hub_id='timm/', + num_classes=11821), # MaxViT models ported from official Tensorflow impl 'maxvit_tiny_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_tiny_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_tiny_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_tiny_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_small_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_small_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_small_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_base_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_224.in1k', + hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD), 'maxvit_large_tf_384.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_224.in21k': _cfg( url=''), 'maxvit_base_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_base_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_base_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_224.in21k': _cfg( url=''), 'maxvit_large_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_large_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_large_tf_512.in21k_ft_in1k', + hf_hub_id='timm/', input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_224.in21k': _cfg( url=''), 'maxvit_xlarge_tf_384.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_384.in21k_ft_in1k', - input_size=(3, 384, 384), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'), 'maxvit_xlarge_tf_512.in21k_ft_in1k': _cfg( - hf_hub_id='timm/maxvit_xlarge_tf_512.in21k_ft_in1k', - input_size=(3, 512, 512), crop_pct=1.0, crop_mode='squash'), + hf_hub_id='timm/', + input_size=(3, 512, 512), pool_size=(16, 16), crop_pct=1.0, crop_mode='squash'), }) @@ -1978,6 +2125,11 @@ def coatnet_rmlp_2_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_2_rw_224', pretrained=pretrained, **kwargs) +@register_model +def coatnet_rmlp_2_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('coatnet_rmlp_2_rw_384', pretrained=pretrained, **kwargs) + + @register_model def coatnet_rmlp_3_rw_224(pretrained=False, **kwargs): return _create_maxxvit('coatnet_rmlp_3_rw_224', pretrained=pretrained, **kwargs) @@ -2068,6 +2220,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs) +@register_model +def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs) + + @register_model def maxvit_tiny_pm_256(pretrained=False, **kwargs): return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs) @@ -2089,13 +2251,23 @@ def maxxvit_rmlp_small_rw_256(pretrained=False, **kwargs): @register_model -def maxxvit_rmlp_base_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_nano_rw_256(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_nano_rw_256', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_224', pretrained=pretrained, **kwargs) + + +@register_model +def maxxvitv2_rmlp_base_rw_384(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_base_rw_384', pretrained=pretrained, **kwargs) @register_model -def maxxvit_rmlp_large_rw_224(pretrained=False, **kwargs): - return _create_maxxvit('maxxvit_rmlp_large_rw_224', pretrained=pretrained, **kwargs) +def maxxvitv2_rmlp_large_rw_224(pretrained=False, **kwargs): + return _create_maxxvit('maxxvitv2_rmlp_large_rw_224', pretrained=pretrained, **kwargs) @register_model diff --git a/timm/models/mobilevit.py b/timm/models/mobilevit.py index 3d2ae84a..8e8f4428 100644 --- a/timm/models/mobilevit.py +++ b/timm/models/mobilevit.py @@ -266,9 +266,16 @@ class MobileVitBlock(nn.Module): self.transformer = nn.Sequential(*[ TransformerBlock( - transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True, - attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate, - act_layer=layers.act, norm_layer=transformer_norm_layer) + transformer_dim, + mlp_ratio=mlp_ratio, + num_heads=num_heads, + qkv_bias=True, + attn_drop=attn_drop, + drop=drop, + drop_path=drop_path_rate, + act_layer=layers.act, + norm_layer=transformer_norm_layer, + ) for _ in range(transformer_depth) ]) self.norm = transformer_norm_layer(transformer_dim) diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 48f91b35..f9a90ab3 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -17,7 +17,7 @@ Status: Hacked together by / copyright Ross Wightman, 2021. """ from collections import OrderedDict -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from typing import Tuple, Optional @@ -159,11 +159,25 @@ class NfCfg: def _nfres_cfg( - depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None): + depths, + channels=(256, 512, 1024, 2048), + group_size=None, + act_layer='relu', + attn_layer=None, + attn_kwargs=None, +): attn_kwargs = attn_kwargs or {} cfg = NfCfg( - depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25, - group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='7x7_pool', + stem_chs=64, + bottle_ratio=0.25, + group_size=group_size, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) return cfg @@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): num_features = 1280 * channels[-1] // 440 attn_kwargs = dict(rd_ratio=0.5) cfg = NfCfg( - depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, - num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='3x3', + group_size=8, + width_factor=0.75, + bottle_ratio=2.25, + num_features=num_features, + reg=True, + attn_layer='se', + attn_kwargs=attn_kwargs, + ) return cfg def _nfnet_cfg( - depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., - act_layer='gelu', attn_layer='se', attn_kwargs=None): + depths, + channels=(256, 512, 1536, 1536), + group_size=128, + bottle_ratio=0.5, + feat_mult=2., + act_layer='gelu', + attn_layer='se', + attn_kwargs=None, +): num_features = int(channels[-1] * feat_mult) attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) cfg = NfCfg( - depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, - bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, - attn_layer=attn_layer, attn_kwargs=attn_kwargs) + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=group_size, + bottle_ratio=bottle_ratio, + extra_conv=True, + num_features=num_features, + act_layer=act_layer, + attn_layer=attn_layer, + attn_kwargs=attn_kwargs, + ) return cfg -def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): +def _dm_nfnet_cfg( + depths, + channels=(256, 512, 1536, 1536), + act_layer='gelu', + skipinit=True, +): cfg = NfCfg( - depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, - bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, - num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) + depths=depths, + channels=channels, + stem_type='deep_quad', + stem_chs=128, + group_size=128, + bottle_ratio=0.5, + extra_conv=True, + gamma_in_act=True, + same_padding=True, + skipinit=skipinit, + num_features=int(channels[-1] * 2.0), + act_layer=act_layer, + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.5), + ) return cfg @@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.): class DownsampleAvg(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + conv_layer=ScaledStdConv2d, + ): """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation.""" super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 @@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module): """ def __init__( - self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, - alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False, - skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.): + self, + in_chs, + out_chs=None, + stride=1, + dilation=1, + first_dilation=None, + alpha=1.0, + beta=1.0, + bottle_ratio=0.25, + group_size=None, + ch_div=1, + reg=True, + extra_conv=False, + skipinit=False, + attn_layer=None, + attn_gain=2.0, + act_layer=None, + conv_layer=None, + drop_path_rate=0., + ): super().__init__() first_dilation = first_dilation or dilation out_chs = out_chs or in_chs @@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module): if in_chs != out_chs or stride != 1 or dilation != first_dilation: self.downsample = DownsampleAvg( - in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer) + in_chs, + out_chs, + stride=stride, + dilation=dilation, + first_dilation=first_dilation, + conv_layer=conv_layer, + ) else: self.downsample = None @@ -452,14 +538,33 @@ class NormFreeNet(nn.Module): for what it is/does. Approx 8-10% throughput loss. """ def __init__( - self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - drop_rate=0., drop_path_rate=0. + self, + cfg: NfCfg, + num_classes=1000, + in_chans=3, + global_pool='avg', + output_stride=32, + drop_rate=0., + drop_path_rate=0., + **kwargs, ): + """ + Args: + cfg (NfCfg): Model architecture configuration + num_classes (int): Number of classifier classes (default: 1000) + in_chans (int): Number of input channels (default: 3) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False + cfg = replace(cfg, **kwargs) assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})." conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d if cfg.gamma_in_act: @@ -472,7 +577,12 @@ class NormFreeNet(nn.Module): stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div) self.stem, stem_stride, stem_feat = create_stem( - in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer) + in_chans, + stem_chs, + cfg.stem_type, + conv_layer=conv_layer, + act_layer=act_layer, + ) self.feature_info = [stem_feat] drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)] diff --git a/timm/models/regnet.py b/timm/models/regnet.py index e1cc821b..63c9b57f 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -14,7 +14,7 @@ Weights from original impl have been modified Hacked together by / Copyright 2020 Ross Wightman """ import math -from dataclasses import dataclass +from dataclasses import dataclass, replace from functools import partial from typing import Optional, Union, Callable @@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la def create_shortcut( - downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False): + downsample_type, + in_chs, + out_chs, + kernel_size, + stride, + dilation=(1, 1), + norm_layer=None, + preact=False, +): assert downsample_type in ('avg', 'conv1x1', '', None) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) @@ -259,9 +267,21 @@ class Bottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, - downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): super(Bottleneck, self).__init__() act_layer = get_act_layer(act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -307,9 +327,21 @@ class PreBottleneck(nn.Module): """ def __init__( - self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, - downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path_rate=0.): + self, + in_chs, + out_chs, + stride=1, + dilation=(1, 1), + bottle_ratio=1, + group_size=1, + se_ratio=0.25, + downsample='conv1x1', + linear_out=False, + act_layer=nn.ReLU, + norm_layer=nn.BatchNorm2d, + drop_block=None, + drop_path_rate=0., + ): super(PreBottleneck, self).__init__() norm_act_layer = get_norm_act_layer(norm_layer, act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -353,8 +385,16 @@ class RegStage(nn.Module): """Stage (sequence of blocks w/ the same output shape).""" def __init__( - self, depth, in_chs, out_chs, stride, dilation, - drop_path_rates=None, block_fn=Bottleneck, **block_kwargs): + self, + depth, + in_chs, + out_chs, + stride, + dilation, + drop_path_rates=None, + block_fn=Bottleneck, + **block_kwargs, + ): super(RegStage, self).__init__() self.grad_checkpointing = False @@ -367,8 +407,13 @@ class RegStage(nn.Module): name = "b{}".format(i + 1) self.add_module( name, block_fn( - block_in_chs, out_chs, stride=block_stride, dilation=block_dilation, - drop_path_rate=dpr, **block_kwargs) + block_in_chs, + out_chs, + stride=block_stride, + dilation=block_dilation, + drop_path_rate=dpr, + **block_kwargs, + ) ) first_dilation = dilation @@ -389,12 +434,35 @@ class RegNet(nn.Module): """ def __init__( - self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', - drop_rate=0., drop_path_rate=0., zero_init_last=True): + self, + cfg: RegNetCfg, + in_chans=3, + num_classes=1000, + output_stride=32, + global_pool='avg', + drop_rate=0., + drop_path_rate=0., + zero_init_last=True, + **kwargs, + ): + """ + + Args: + cfg (RegNetCfg): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + global_pool (str): Global pooling type (default: 'avg') + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + zero_init_last (bool): Zero-init last weight of residual path + kwargs (dict): Extra kwargs overlayed onto cfg + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) + cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs # Construct the stem stem_width = cfg.stem_width @@ -428,7 +496,7 @@ class RegNet(nn.Module): self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() self.num_features = prev_width self.head = ClassifierHead( - in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) @@ -461,8 +529,12 @@ class RegNet(nn.Module): dict(zip(arg_names, params)) for params in zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)] common_args = dict( - downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out, - act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + downsample=cfg.downsample, + se_ratio=cfg.se_ratio, + linear_out=cfg.linear_out, + act_layer=cfg.act_layer, + norm_layer=cfg.norm_layer, + ) return per_stage_args, common_args @torch.jit.ignore @@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False): def _filter_fn(state_dict): - """ convert patch embedding weight from manual patchify + linear proj to conv""" if 'classy_state_dict' in state_dict: import re state_dict = state_dict['classy_state_dict']['base_model']['model'] diff --git a/timm/models/res2net.py b/timm/models/res2net.py index 607ba722..29a49953 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -156,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs) - return _create_res2net('res2net50_26w_4s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4)) + return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -167,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs) - return _create_res2net('res2net101_26w_4s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4)) + return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -178,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs) - return _create_res2net('res2net50_26w_6s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6)) + return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -189,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs) - return _create_res2net('res2net50_26w_8s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8)) + return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -200,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs) - return _create_res2net('res2net50_48w_2s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2)) + return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -211,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs) - return _create_res2net('res2net50_14w_8s', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8)) + return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs)) @register_model @@ -222,5 +222,5 @@ def res2next50(pretrained=False, **kwargs): pretrained (bool): If True, returns a model pre-trained on ImageNet """ model_args = dict( - block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs) - return _create_res2net('res2next50', pretrained, **model_args) + block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4)) + return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 853ee1d0..38303f9c 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -163,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[1, 1, 1, 1], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -174,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[2, 2, 2, 2], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -186,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -198,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 4, 23, 3], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -210,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 24, 36, 3], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -222,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 30, 48, 8], stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1, - block_args=dict(radix=2, avd=True, avd_first=False), **kwargs) - return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=2, avd=True, avd_first=False)) + return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -233,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2, - block_args=dict(radix=4, avd=True, avd_first=True), **kwargs) - return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=4, avd=True, avd_first=True)) + return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) @register_model @@ -244,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs): model_kwargs = dict( block=ResNestBottleneck, layers=[3, 4, 6, 3], stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4, - block_args=dict(radix=1, avd=True, avd_first=True), **kwargs) - return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs) + block_args=dict(radix=1, avd=True, avd_first=True)) + return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs)) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2976c1f9..200280b3 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \ - create_classifier + get_act_layer, get_norm_layer, create_classifier from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq from ._registry import register_model, model_entrypoint @@ -500,7 +500,14 @@ class Bottleneck(nn.Module): def downsample_conv( - in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + first_dilation=None, + norm_layer=None, +): norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1 @@ -514,7 +521,14 @@ def downsample_conv( def downsample_avg( - in_channels, out_channels, kernel_size, stride=1, dilation=1, first_dilation=None, norm_layer=None): + in_channels, + out_channels, + kernel_size, + stride=1, + dilation=1, + first_dilation=None, + norm_layer=None, +): norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 if stride == 1 and dilation == 1: @@ -627,31 +641,6 @@ class ResNet(nn.Module): SENet-154 - 3 layer deep 3x3 stem (same as v1c-v1s), stem_width = 64, cardinality=64, reduction by 2 on width of first bottleneck convolution, 3x3 downsample convs after first block - - Parameters - ---------- - block : Block, class for the residual block. Options are BasicBlockGl, BottleneckGl. - layers : list of int, number of layers in each block - num_classes : int, default 1000, number of classification classes. - in_chans : int, default 3, number of input (color) channels. - output_stride : int, default 32, output stride of the network, 32, 16, or 8. - global_pool : str, Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' - cardinality : int, default 1, number of convolution groups for 3x3 conv in Bottleneck. - base_width : int, default 64, factor determining bottleneck channels. `planes * base_width / 64 * cardinality` - stem_width : int, default 64, number of channels in stem convolutions - stem_type : str, default '' - The type of stem: - * '', default - a single 7x7 conv with a width of stem_width - * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 - * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 - block_reduce_first : int, default 1 - Reduction factor for first convolution output width of residual blocks, 1 for all archs except senets, where 2 - down_kernel_size : int, default 1, kernel size of residual block downsample path, 1x1 for most, 3x3 for senets - avg_down : bool, default False, use average pooling for projection skip connection between stages/downsample. - act_layer : nn.Module, activation layer - norm_layer : nn.Module, normalization layer - aa_layer : nn.Module, anti-aliasing layer - drop_rate : float, default 0. Dropout probability before classifier, for training """ def __init__( @@ -679,12 +668,45 @@ class ResNet(nn.Module): zero_init_last=True, block_args=None, ): + """ + Args: + block (nn.Module): class for the residual block. Options are BasicBlock, Bottleneck. + layers (List[int]) : number of layers in each block + num_classes (int): number of classification classes (default 1000) + in_chans (int): number of input (color) channels. (default 3) + output_stride (int): output stride of the network, 32, 16, or 8. (default 32) + global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') + cardinality (int): number of convolution groups for 3x3 conv in Bottleneck. (default 1) + base_width (int): bottleneck channels factor. `planes * base_width / 64 * cardinality` (default 64) + stem_width (int): number of channels in stem convolutions (default 64) + stem_type (str): The type of stem (default ''): + * '', default - a single 7x7 conv with a width of stem_width + * 'deep' - three 3x3 convolution layers of widths stem_width, stem_width, stem_width * 2 + * 'deep_tiered' - three 3x3 conv layers of widths stem_width//4 * 3, stem_width, stem_width * 2 + replace_stem_pool (bool): replace stem max-pooling layer with a 3x3 stride-2 convolution + block_reduce_first (int): Reduction factor for first convolution output width of residual blocks, + 1 for all archs except senets, where 2 (default 1) + down_kernel_size (int): kernel size of residual block downsample path, + 1x1 for most, 3x3 for senets (default: 1) + avg_down (bool): use avg pooling for projection skip connection between stages/downsample (default False) + act_layer (str, nn.Module): activation layer + norm_layer (str, nn.Module): normalization layer + aa_layer (nn.Module): anti-aliasing layer + drop_rate (float): Dropout probability before classifier, for training (default 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default 0.) + drop_block_rate (float): Drop block rate (default 0.) + zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight) + block_args (dict): Extra kwargs to pass through to block module + """ super(ResNet, self).__init__() block_args = block_args or dict() assert output_stride in (8, 16, 32) self.num_classes = num_classes self.drop_rate = drop_rate self.grad_checkpointing = False + + act_layer = get_act_layer(act_layer) + norm_layer = get_norm_layer(norm_layer) # Stem deep_stem = 'deep' in stem_type @@ -823,77 +845,72 @@ def _create_resnet(variant, pretrained=False, **kwargs): def resnet10t(pretrained=False, **kwargs): """Constructs a ResNet-10-T model. """ - model_args = dict( - block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) - return _create_resnet('resnet10t', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet14t(pretrained=False, **kwargs): """Constructs a ResNet-14-T model. """ - model_args = dict( - block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) - return _create_resnet('resnet14t', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[1, 1, 1, 1], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet14t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet18(pretrained=False, **kwargs): """Constructs a ResNet-18 model. """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) - return _create_resnet('resnet18', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) + return _create_resnet('resnet18', pretrained, **dict(model_args, **kwargs)) @register_model def resnet18d(pretrained=False, **kwargs): """Constructs a ResNet-18-D model. """ - model_args = dict( - block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet18d', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet18d', pretrained, **dict(model_args, **kwargs)) @register_model def resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model. """ - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('resnet34', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3]) + return _create_resnet('resnet34', pretrained, **dict(model_args, **kwargs)) @register_model def resnet34d(pretrained=False, **kwargs): """Constructs a ResNet-34-D model. """ - model_args = dict( - block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet34d', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet34d', pretrained, **dict(model_args, **kwargs)) @register_model def resnet26(pretrained=False, **kwargs): """Constructs a ResNet-26 model. """ - model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], **kwargs) - return _create_resnet('resnet26', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2]) + return _create_resnet('resnet26', pretrained, **dict(model_args, **kwargs)) @register_model def resnet26t(pretrained=False, **kwargs): """Constructs a ResNet-26-T model. """ - model_args = dict( - block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) - return _create_resnet('resnet26t', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet26t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet26d(pretrained=False, **kwargs): """Constructs a ResNet-26-D model. """ - model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet26d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet26d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -901,83 +918,79 @@ def resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('resnet50', pretrained, **model_args) + return _create_resnet('resnet50', pretrained, **dict(model_args, **kwargs)) @register_model def resnet50d(pretrained=False, **kwargs) -> ResNet: """Constructs a ResNet-50-D model. """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet50d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet50d', pretrained, **dict(model_args, **kwargs)) @register_model def resnet50t(pretrained=False, **kwargs): """Constructs a ResNet-50-T model. """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, **kwargs) - return _create_resnet('resnet50t', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True) + return _create_resnet('resnet50t', pretrained, **dict(model_args, **kwargs)) @register_model def resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) - return _create_resnet('resnet101', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3]) + return _create_resnet('resnet101', pretrained, **dict(model_args, **kwargs)) @register_model def resnet101d(pretrained=False, **kwargs): """Constructs a ResNet-101-D model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet101d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet101d', pretrained, **dict(model_args, **kwargs)) @register_model def resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model. """ - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) - return _create_resnet('resnet152', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3]) + return _create_resnet('resnet152', pretrained, **dict(model_args, **kwargs)) @register_model def resnet152d(pretrained=False, **kwargs): """Constructs a ResNet-152-D model. """ - model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet152d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet152d', pretrained, **dict(model_args, **kwargs)) @register_model def resnet200(pretrained=False, **kwargs): """Constructs a ResNet-200 model. """ - model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], **kwargs) - return _create_resnet('resnet200', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3]) + return _create_resnet('resnet200', pretrained, **dict(model_args, **kwargs)) @register_model def resnet200d(pretrained=False, **kwargs): """Constructs a ResNet-200-D model. """ - model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnet200d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnet200d', pretrained, **dict(model_args, **kwargs)) @register_model def tv_resnet34(pretrained=False, **kwargs): """Constructs a ResNet-34 model with original Torchvision weights. """ - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('tv_resnet34', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3]) + return _create_resnet('tv_resnet34', pretrained, **dict(model_args, **kwargs)) @register_model @@ -985,23 +998,23 @@ def tv_resnet50(pretrained=False, **kwargs): """Constructs a ResNet-50 model with original Torchvision weights. """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('tv_resnet50', pretrained, **model_args) + return _create_resnet('tv_resnet50', pretrained, **dict(model_args, **kwargs)) @register_model def tv_resnet101(pretrained=False, **kwargs): """Constructs a ResNet-101 model w/ Torchvision pretrained weights. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], **kwargs) - return _create_resnet('tv_resnet101', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3]) + return _create_resnet('tv_resnet101', pretrained, **dict(model_args, **kwargs)) @register_model def tv_resnet152(pretrained=False, **kwargs): """Constructs a ResNet-152 model w/ Torchvision pretrained weights. """ - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], **kwargs) - return _create_resnet('tv_resnet152', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3]) + return _create_resnet('tv_resnet152', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1012,8 +1025,8 @@ def wide_resnet50_2(pretrained=False, **kwargs): convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128, **kwargs) - return _create_resnet('wide_resnet50_2', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], base_width=128) + return _create_resnet('wide_resnet50_2', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1023,8 +1036,8 @@ def wide_resnet101_2(pretrained=False, **kwargs): which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128, **kwargs) - return _create_resnet('wide_resnet101_2', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], base_width=128) + return _create_resnet('wide_resnet101_2', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1039,8 +1052,8 @@ def resnet50_gn(pretrained=False, **kwargs): def resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('resnext50_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + return _create_resnet('resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1049,40 +1062,40 @@ def resnext50d_32x4d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnext50d_32x4d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnext50d_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def resnext101_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt-101 32x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('resnext101_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) + return _create_resnet('resnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def resnext101_32x8d(pretrained=False, **kwargs): """Constructs a ResNeXt-101 32x8d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - return _create_resnet('resnext101_32x8d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + return _create_resnet('resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model def resnext101_64x4d(pretrained=False, **kwargs): """Constructs a ResNeXt101-64x4d model. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4, **kwargs) - return _create_resnet('resnext101_64x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=64, base_width=4) + return _create_resnet('resnext101_64x4d', pretrained, **dict(model_args, **kwargs)) @register_model def tv_resnext50_32x4d(pretrained=False, **kwargs): """Constructs a ResNeXt50-32x4d model with original Torchvision weights. """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('tv_resnext50_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + return _create_resnet('tv_resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1092,8 +1105,8 @@ def ig_resnext101_32x8d(pretrained=False, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - return _create_resnet('ig_resnext101_32x8d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + return _create_resnet('ig_resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1103,8 +1116,8 @@ def ig_resnext101_32x16d(pretrained=False, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - return _create_resnet('ig_resnext101_32x16d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) + return _create_resnet('ig_resnext101_32x16d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1114,8 +1127,8 @@ def ig_resnext101_32x32d(pretrained=False, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32, **kwargs) - return _create_resnet('ig_resnext101_32x32d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=32) + return _create_resnet('ig_resnext101_32x32d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1125,8 +1138,8 @@ def ig_resnext101_32x48d(pretrained=False, **kwargs): `"Exploring the Limits of Weakly Supervised Pretraining" `_ Weights from https://pytorch.org/hub/facebookresearch_WSL-Images_resnext/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48, **kwargs) - return _create_resnet('ig_resnext101_32x48d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=48) + return _create_resnet('ig_resnext101_32x48d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1135,8 +1148,8 @@ def ssl_resnet18(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) - return _create_resnet('ssl_resnet18', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) + return _create_resnet('ssl_resnet18', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1146,7 +1159,7 @@ def ssl_resnet50(pretrained=False, **kwargs): Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('ssl_resnet50', pretrained, **model_args) + return _create_resnet('ssl_resnet50', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1155,8 +1168,8 @@ def ssl_resnext50_32x4d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('ssl_resnext50_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + return _create_resnet('ssl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1165,8 +1178,8 @@ def ssl_resnext101_32x4d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('ssl_resnext101_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) + return _create_resnet('ssl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1175,8 +1188,8 @@ def ssl_resnext101_32x8d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - return _create_resnet('ssl_resnext101_32x8d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + return _create_resnet('ssl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1185,8 +1198,8 @@ def ssl_resnext101_32x16d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - return _create_resnet('ssl_resnext101_32x16d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) + return _create_resnet('ssl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1196,8 +1209,8 @@ def swsl_resnet18(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], **kwargs) - return _create_resnet('swsl_resnet18', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2]) + return _create_resnet('swsl_resnet18', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1208,7 +1221,7 @@ def swsl_resnet50(pretrained=False, **kwargs): Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], **kwargs) - return _create_resnet('swsl_resnet50', pretrained, **model_args) + return _create_resnet('swsl_resnet50', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1218,8 +1231,8 @@ def swsl_resnext50_32x4d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('swsl_resnext50_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4) + return _create_resnet('swsl_resnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1229,8 +1242,8 @@ def swsl_resnext101_32x4d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, **kwargs) - return _create_resnet('swsl_resnext101_32x4d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4) + return _create_resnet('swsl_resnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1240,8 +1253,8 @@ def swsl_resnext101_32x8d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, **kwargs) - return _create_resnet('swsl_resnext101_32x8d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8) + return _create_resnet('swsl_resnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1251,8 +1264,8 @@ def swsl_resnext101_32x16d(pretrained=False, **kwargs): `"Billion-scale Semi-Supervised Learning for Image Classification" `_ Weights from https://github.com/facebookresearch/semi-supervised-ImageNet1K-models/ """ - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16, **kwargs) - return _create_resnet('swsl_resnext101_32x16d', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=16) + return _create_resnet('swsl_resnext101_32x16d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1263,8 +1276,8 @@ def ecaresnet26t(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[2, 2, 2, 2], stem_width=32, - stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet26t', pretrained, **model_args) + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet26t', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1273,8 +1286,8 @@ def ecaresnet50d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet50d', pretrained, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1284,8 +1297,8 @@ def ecaresnet50d_pruned(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) @register_model @@ -1295,8 +1308,8 @@ def ecaresnet50t(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, - stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet50t', pretrained, **model_args) + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet50t', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1305,8 +1318,8 @@ def ecaresnetlight(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[1, 1, 11, 3], stem_width=32, avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnetlight', pretrained, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnetlight', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1315,8 +1328,8 @@ def ecaresnet101d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet101d', pretrained, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet101d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1326,8 +1339,8 @@ def ecaresnet101d_pruned(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet101d_pruned', pretrained, pruned=True, **dict(model_args, **kwargs)) @register_model @@ -1336,8 +1349,8 @@ def ecaresnet200d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet200d', pretrained, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet200d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1346,8 +1359,8 @@ def ecaresnet269d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnet269d', pretrained, **model_args) + block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnet269d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1358,8 +1371,8 @@ def ecaresnext26t_32x4d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, - stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnext26t_32x4d', pretrained, **model_args) + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnext26t_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1370,54 +1383,54 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, - stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca'), **kwargs) - return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='eca')) + return _create_resnet('ecaresnext50t_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet18(pretrained=False, **kwargs): - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet18', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se')) + return _create_resnet('seresnet18', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet34(pretrained=False, **kwargs): - model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet34', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) + return _create_resnet('seresnet34', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet50(pretrained=False, **kwargs): - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet50', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], block_args=dict(attn_layer='se')) + return _create_resnet('seresnet50', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet50t(pretrained=False, **kwargs): model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', avg_down=True, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet50t', pretrained, **model_args) + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep_tiered', + avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnet50t', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet101(pretrained=False, **kwargs): - model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet101', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 23, 3], block_args=dict(attn_layer='se')) + return _create_resnet('seresnet101', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet152(pretrained=False, **kwargs): - model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet152', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 8, 36, 3], block_args=dict(attn_layer='se')) + return _create_resnet('seresnet152', pretrained, **dict(model_args, **kwargs)) @register_model def seresnet152d(pretrained=False, **kwargs): model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet152d', pretrained, **model_args) + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', + avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnet152d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1425,9 +1438,9 @@ def seresnet200d(pretrained=False, **kwargs): """Constructs a ResNet-200-D model with SE attn. """ model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet200d', pretrained, **model_args) + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', + avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnet200d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1435,9 +1448,9 @@ def seresnet269d(pretrained=False, **kwargs): """Constructs a ResNet-269-D model with SE attn. """ model_args = dict( - block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnet269d', pretrained, **model_args) + block=Bottleneck, layers=[3, 30, 48, 8], stem_width=32, stem_type='deep', + avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnet269d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1448,8 +1461,8 @@ def seresnext26d_32x4d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, - stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext26d_32x4d', pretrained, **model_args) + stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnext26d_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1460,8 +1473,8 @@ def seresnext26t_32x4d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[2, 2, 2, 2], cardinality=32, base_width=4, stem_width=32, - stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext26t_32x4d', pretrained, **model_args) + stem_type='deep_tiered', avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnext26t_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1477,24 +1490,24 @@ def seresnext26tn_32x4d(pretrained=False, **kwargs): def seresnext50_32x4d(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext50_32x4d', pretrained, **model_args) + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext50_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def seresnext101_32x4d(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=4, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext101_32x4d', pretrained, **model_args) + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101_32x4d', pretrained, **dict(model_args, **kwargs)) @register_model def seresnext101_32x8d(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext101_32x8d', pretrained, **model_args) + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1502,32 +1515,32 @@ def seresnext101d_32x8d(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, stem_width=32, stem_type='deep', avg_down=True, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnext101d_32x8d', pretrained, **model_args) + block_args=dict(attn_layer='se')) + return _create_resnet('seresnext101d_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model def senet154(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', - down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('senet154', pretrained, **model_args) + down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se')) + return _create_resnet('senet154', pretrained, **dict(model_args, **kwargs)) @register_model def resnetblur18(pretrained=False, **kwargs): """Constructs a ResNet-18 model with blur anti-aliasing """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur18', pretrained, **model_args) + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d) + return _create_resnet('resnetblur18', pretrained, **dict(model_args, **kwargs)) @register_model def resnetblur50(pretrained=False, **kwargs): """Constructs a ResNet-50 model with blur anti-aliasing """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur50', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d) + return _create_resnet('resnetblur50', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1536,8 +1549,8 @@ def resnetblur50d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur50d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnetblur50d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1546,16 +1559,25 @@ def resnetblur101d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur101d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnetblur101d', pretrained, **dict(model_args, **kwargs)) + + +@register_model +def resnetaa34d(pretrained=False, **kwargs): + """Constructs a ResNet-34-D model w/ avgpool anti-aliasing + """ + model_args = dict( + block=BasicBlock, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnetaa34d', pretrained, **dict(model_args, **kwargs)) @register_model def resnetaa50(pretrained=False, **kwargs): """Constructs a ResNet-50 model with avgpool anti-aliasing """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, **kwargs) - return _create_resnet('resnetaa50', pretrained, **model_args) + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d) + return _create_resnet('resnetaa50', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1564,8 +1586,8 @@ def resnetaa50d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa50d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnetaa50d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1574,8 +1596,8 @@ def resnetaa101d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa101d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True) + return _create_resnet('resnetaa101d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1584,8 +1606,8 @@ def seresnetaa50d(pretrained=False, **kwargs): """ model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnetaa50d', pretrained, **model_args) + stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se')) + return _create_resnet('seresnetaa50d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1595,8 +1617,8 @@ def seresnextaa101d_32x8d(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, - block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args) + block_args=dict(attn_layer='se')) + return _create_resnet('seresnextaa101d_32x8d', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1608,8 +1630,8 @@ def resnetrs50(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs50', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs50', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1621,8 +1643,8 @@ def resnetrs101(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs101', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs101', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1634,8 +1656,8 @@ def resnetrs152(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs152', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs152', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1647,8 +1669,8 @@ def resnetrs200(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs200', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs200', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1660,8 +1682,8 @@ def resnetrs270(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs270', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs270', pretrained, **dict(model_args, **kwargs)) @@ -1674,8 +1696,8 @@ def resnetrs350(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs350', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs350', pretrained, **dict(model_args, **kwargs)) @register_model @@ -1687,5 +1709,5 @@ def resnetrs420(pretrained=False, **kwargs): attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs420', pretrained, **model_args) + avg_down=True, block_args=dict(attn_layer=attn_layer)) + return _create_resnet('resnetrs420', pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index a55f48ac..41e29e12 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -37,7 +37,7 @@ import torch.nn as nn from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \ - ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d + ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer from ._builder import build_model_with_cfg from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv from ._registry import register_model @@ -276,8 +276,16 @@ class Bottleneck(nn.Module): class DownsampleConv(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True, - conv_layer=None, norm_layer=None): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): super(DownsampleConv, self).__init__() self.conv = conv_layer(in_chs, out_chs, 1, stride=stride) self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False) @@ -288,8 +296,16 @@ class DownsampleConv(nn.Module): class DownsampleAvg(nn.Module): def __init__( - self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, - preact=True, conv_layer=None, norm_layer=None): + self, + in_chs, + out_chs, + stride=1, + dilation=1, + first_dilation=None, + preact=True, + conv_layer=None, + norm_layer=None, + ): """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" super(DownsampleAvg, self).__init__() avg_stride = stride if dilation == 1 else 1 @@ -334,9 +350,18 @@ class ResNetStage(nn.Module): drop_path_rate = block_dpr[block_idx] if block_dpr else 0. stride = stride if block_idx == 0 else 1 self.blocks.add_module(str(block_idx), block_fn( - prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups, - first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate, - **layer_kwargs, **block_kwargs)) + prev_chs, + out_chs, + stride=stride, + dilation=dilation, + bottle_ratio=bottle_ratio, + groups=groups, + first_dilation=first_dilation, + proj_layer=proj_layer, + drop_path_rate=drop_path_rate, + **layer_kwargs, + **block_kwargs, + )) prev_chs = out_chs first_dilation = dilation proj_layer = None @@ -413,21 +438,49 @@ class ResNetV2(nn.Module): avg_down=False, preact=True, act_layer=nn.ReLU, - conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32), + conv_layer=StdConv2d, drop_rate=0., drop_path_rate=0., zero_init_last=False, ): + """ + Args: + layers (List[int]) : number of layers in each block + channels (List[int]) : number of channels in each block: + num_classes (int): number of classification classes (default 1000) + in_chans (int): number of input (color) channels. (default 3) + global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg') + output_stride (int): output stride of the network, 32, 16, or 8. (default 32) + width_factor (int): channel (width) multiplication factor + stem_chs (int): stem width (default: 64) + stem_type (str): stem type (default: '' == 7x7) + avg_down (bool): average pooling in residual downsampling (default: False) + preact (bool): pre-activiation (default: True) + act_layer (Union[str, nn.Module]): activation layer + norm_layer (Union[str, nn.Module]): normalization layer + conv_layer (nn.Module): convolution module + drop_rate: classifier dropout rate (default: 0.) + drop_path_rate: stochastic depth rate (default: 0.) + zero_init_last: zero-init last weight in residual path (default: False) + """ super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate wf = width_factor + norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + act_layer = get_act_layer(act_layer) self.feature_info = [] stem_chs = make_div(stem_chs * wf) self.stem = create_resnetv2_stem( - in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer) + in_chans, + stem_chs, + stem_type, + preact, + conv_layer=conv_layer, + norm_layer=norm_layer, + ) stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm' self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat)) @@ -693,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs): @register_model def resnetv2_50(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50', pretrained=pretrained, - layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) + return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_50d(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50d', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_50t(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50t', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, - stem_type='tiered', avg_down=True, **kwargs) + stem_type='tiered', avg_down=True) + return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_101(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101', pretrained=pretrained, - layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) + return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_101d(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_101d', pretrained=pretrained, + model_args = dict( layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_152(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152', pretrained=pretrained, - layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs) + model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d) + return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_152d(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_152d', pretrained=pretrained, + model_args = dict( layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs)) # Experimental configs (may change / be removed) @register_model def resnetv2_50d_gn(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50d_gn', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_50d_evob(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50d_evob', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0, - stem_type='deep', avg_down=True, zero_init_last=True, **kwargs) + stem_type='deep', avg_down=True, zero_init_last=True) + return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_50d_evos(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50d_evos', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs)) @register_model def resnetv2_50d_frn(pretrained=False, **kwargs): - return _create_resnetv2( - 'resnetv2_50d_frn', pretrained=pretrained, + model_args = dict( layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d, - stem_type='deep', avg_down=True, **kwargs) + stem_type='deep', avg_down=True) + return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs)) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index d6865549..d32f9dea 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -697,6 +697,13 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ + # re-finetuned augreg 21k FT on in1k weights + 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), + 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( + hf_hub_id='timm/'), + # How to train your ViT (augreg) weights, pretrained on 21k FT on in1k 'vit_tiny_patch16_224.augreg_in21k_ft_in1k': _cfg( url='https://storage.googleapis.com/vit_models/augreg/Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz', @@ -751,13 +758,6 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', custom_load=True, input_size=(3, 384, 384), crop_pct=1.0), - # re-finetuned augreg 21k FT on in1k weights - 'vit_base_patch16_224.augreg2_in21k_ft_in1k': _cfg( - hf_hub_id='timm/'), - 'vit_base_patch16_384.augreg2_in21k_ft_in1k': _cfg(), - 'vit_base_patch8_224.augreg2_in21k_ft_in1k': _cfg( - hf_hub_id='timm/'), - # patch models (weights from official Google JAX impl) pretrained on in21k FT on in1k 'vit_base_patch16_224.orig_in21k_ft_in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_224-80ecf9dd.pth', @@ -802,7 +802,6 @@ default_cfgs = generate_default_cfgs({ 'vit_giant_patch14_224.untrained': _cfg(url=''), 'vit_gigantic_patch14_224.untrained': _cfg(url=''), - # patch models, imagenet21k (weights from official Google JAX impl) 'vit_large_patch32_224.orig_in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_patch32_224_in21k-9046d2e7.pth', @@ -869,7 +868,6 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0), - # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', @@ -880,7 +878,7 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', mean=(0., 0., 0.), std=(1., 1., 1.), crop_pct=0.875, interpolation='bilinear'), - # custom timm variants + # Custom timm variants 'vit_base_patch16_rpn_224.in1k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth', hf_hub_id='timm/'), @@ -896,52 +894,6 @@ default_cfgs = generate_default_cfgs({ 'vit_base_patch16_gap_224': _cfg(), # CLIP pretrained image tower and related fine-tuned weights - 'vit_base_patch32_clip_224.laion2b': _cfg( - hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', - hf_hub_filename='open_clip_pytorch_model.bin', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_base_patch16_clip_224.laion2b': _cfg( - #hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', - hf_hub_filename='open_clip_pytorch_model.bin', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), - 'vit_large_patch14_clip_224.laion2b': _cfg( - hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', - hf_hub_filename='open_clip_pytorch_model.bin', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), - 'vit_huge_patch14_clip_224.laion2b': _cfg( - hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', - hf_hub_filename='open_clip_pytorch_model.bin', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), - 'vit_giant_patch14_clip_224.laion2b': _cfg( - hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', - hf_hub_filename='open_clip_pytorch_model.bin', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), - - 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), - 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), - 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), - 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), - 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, - crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), - 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), - 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( - hf_hub_id='', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), - 'vit_base_patch32_clip_224.laion2b_ft_in12k_in1k': _cfg( hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), @@ -973,28 +925,52 @@ default_cfgs = generate_default_cfgs({ mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), - 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), - 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( + 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), - 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', - mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), - 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), + 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), - 'vit_base_patch32_clip_224.openai': _cfg( + 'vit_base_patch32_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_base_patch16_clip_224.openai': _cfg( + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), + 'vit_base_patch16_clip_224.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), - 'vit_large_patch14_clip_224.openai': _cfg( + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_base_patch16_clip_384.laion2b_ft_in1k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 384, 384), crop_mode='squash'), + 'vit_large_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0), + 'vit_large_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + 'vit_huge_patch14_clip_224.laion2b_ft_in1k': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), + 'vit_huge_patch14_clip_336.laion2b_ft_in1k': _cfg( + hf_hub_id='', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, + crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), 'vit_base_patch32_clip_224.openai_ft_in1k': _cfg( hf_hub_id='timm/', @@ -1010,30 +986,21 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), - 'vit_base_patch32_clip_224.openai_ft_in12k_in1k': _cfg( - #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k_in1k', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD), - 'vit_base_patch32_clip_384.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), - 'vit_base_patch16_clip_224.openai_ft_in12k_in1k': _cfg( - hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=0.95), - 'vit_base_patch16_clip_384.openai_ft_in12k_in1k': _cfg( + 'vit_base_patch32_clip_224.laion2b_ft_in12k': _cfg( + #hf_hub_id='timm/vit_base_patch32_clip_224.laion2b_ft_in12k', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_base_patch16_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - crop_pct=0.95, input_size=(3, 384, 384), crop_mode='squash'), - 'vit_large_patch14_clip_224.openai_ft_in12k_in1k': _cfg( + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), + 'vit_large_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0), - 'vit_large_patch14_clip_336.openai_ft_in12k_in1k': _cfg( + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=11821), + 'vit_huge_patch14_clip_224.laion2b_ft_in12k': _cfg( hf_hub_id='timm/', - mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, - crop_pct=1.0, input_size=(3, 336, 336), crop_mode='squash'), + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), 'vit_base_patch32_clip_224.openai_ft_in12k': _cfg( - #hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', + # hf_hub_id='timm/vit_base_patch32_clip_224.openai_ft_in12k', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=11821), 'vit_base_patch16_clip_224.openai_ft_in12k': _cfg( hf_hub_id='timm/', @@ -1042,6 +1009,41 @@ default_cfgs = generate_default_cfgs({ hf_hub_id='timm/', mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=11821), + 'vit_base_patch32_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-B-32-laion2B-s34B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.laion2b': _cfg( + # hf_hub_id='laion/CLIP-ViT-B-16-laion2B-s34B-b88K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=512), + 'vit_large_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-L-14-laion2B-s32B-b82K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD, crop_pct=1.0, num_classes=768), + 'vit_huge_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-H-14-laion2B-s32B-b79K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_giant_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-g-14-laion2B-s12B-b42K', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1024), + 'vit_gigantic_patch14_clip_224.laion2b': _cfg( + hf_hub_id='laion/CLIP-ViT-bigG-14-laion2B-39B-b160k', + hf_hub_filename='open_clip_pytorch_model.bin', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=1280), + + 'vit_base_patch32_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_base_patch16_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512), + 'vit_large_patch14_clip_224.openai': _cfg( + hf_hub_id='timm/', + mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0, num_classes=768), + # experimental (may be removed) 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), @@ -1152,8 +1154,8 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): def vit_tiny_patch16_224(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1161,8 +1163,8 @@ def vit_tiny_patch16_224(pretrained=False, **kwargs): def vit_tiny_patch16_384(pretrained=False, **kwargs): """ ViT-Tiny (Vit-Ti/16) @ 384x384. """ - model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3, **kwargs) - model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3) + model = _create_vision_transformer('vit_tiny_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1170,8 +1172,8 @@ def vit_tiny_patch16_384(pretrained=False, **kwargs): def vit_small_patch32_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1179,8 +1181,8 @@ def vit_small_patch32_224(pretrained=False, **kwargs): def vit_small_patch32_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/32) at 384x384. """ - model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1188,8 +1190,8 @@ def vit_small_patch32_384(pretrained=False, **kwargs): def vit_small_patch16_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1197,8 +1199,8 @@ def vit_small_patch16_224(pretrained=False, **kwargs): def vit_small_patch16_384(pretrained=False, **kwargs): """ ViT-Small (ViT-S/16) """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1206,8 +1208,8 @@ def vit_small_patch16_384(pretrained=False, **kwargs): def vit_small_patch8_224(pretrained=False, **kwargs): """ ViT-Small (ViT-S/8) """ - model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6, **kwargs) - model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=384, depth=12, num_heads=6) + model = _create_vision_transformer('vit_small_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1216,8 +1218,8 @@ def vit_base_patch32_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1226,8 +1228,8 @@ def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1236,8 +1238,8 @@ def vit_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1246,8 +1248,8 @@ def vit_base_patch16_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1256,8 +1258,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/8) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) - model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12) + model = _create_vision_transformer('vit_base_patch8_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1265,8 +1267,8 @@ def vit_base_patch8_224(pretrained=False, **kwargs): def vit_large_patch32_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). No pretrained weights. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1275,8 +1277,8 @@ def vit_large_patch32_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/32) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch32_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1285,8 +1287,8 @@ def vit_large_patch16_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 224x224, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1295,8 +1297,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/16) from original paper (https://arxiv.org/abs/2010.11929). ImageNet-1k weights fine-tuned from in21k @ 384x384, source https://github.com/google-research/vision_transformer. """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch16_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1304,8 +1306,8 @@ def vit_large_patch16_384(pretrained=False, **kwargs): def vit_large_patch14_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) """ - model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16) + model = _create_vision_transformer('vit_large_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1313,8 +1315,8 @@ def vit_large_patch14_224(pretrained=False, **kwargs): def vit_huge_patch14_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) from original paper (https://arxiv.org/abs/2010.11929). """ - model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16) + model = _create_vision_transformer('vit_huge_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1322,8 +1324,8 @@ def vit_huge_patch14_224(pretrained=False, **kwargs): def vit_giant_patch14_224(pretrained=False, **kwargs): """ ViT-Giant (little-g) model (ViT-g/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ - model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16) + model = _create_vision_transformer('vit_giant_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1331,8 +1333,9 @@ def vit_giant_patch14_224(pretrained=False, **kwargs): def vit_gigantic_patch14_224(pretrained=False, **kwargs): """ ViT-Gigantic (big-G) model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 """ - model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, **kwargs) - model = _create_vision_transformer('vit_gigantic_patch14_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16) + model = _create_vision_transformer( + 'vit_gigantic_patch14_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1341,8 +1344,9 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) from original paper (https://arxiv.org/abs/2010.11929). Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False) + model = _create_vision_transformer( + 'vit_base_patch16_224_miil', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1352,8 +1356,9 @@ def vit_medium_patch16_gap_240(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_240', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_240', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1363,8 +1368,9 @@ def vit_medium_patch16_gap_256(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_256', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_256', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1374,8 +1380,9 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs): """ model_kwargs = dict( patch_size=16, embed_dim=512, depth=12, num_heads=8, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), qkv_bias=False, init_values=1e-6, fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_medium_patch16_gap_384', pretrained=pretrained, **model_kwargs) + global_pool='avg', qkv_bias=False, init_values=1e-6, fc_norm=False) + model = _create_vision_transformer( + 'vit_medium_patch16_gap_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1384,9 +1391,9 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256 """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, - global_pool=kwargs.get('global_pool', 'avg'), fc_norm=False, **kwargs) - model = _create_vision_transformer('vit_base_patch16_gap_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False) + model = _create_vision_transformer( + 'vit_base_patch16_gap_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1395,8 +1402,9 @@ def vit_base_patch32_clip_224(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 224x224 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_224', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1405,8 +1413,9 @@ def vit_base_patch32_clip_384(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 384x384 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_384', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1415,8 +1424,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs): """ ViT-B/32 CLIP image tower @ 448x448 """ model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch32_clip_448', pretrained=pretrained, **model_kwargs) + patch_size=32, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch32_clip_448', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1424,9 +1434,9 @@ def vit_base_patch32_clip_448(pretrained=False, **kwargs): def vit_base_patch16_clip_224(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch16_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1434,9 +1444,9 @@ def vit_base_patch16_clip_224(pretrained=False, **kwargs): def vit_base_patch16_clip_384(pretrained=False, **kwargs): """ ViT-B/16 CLIP image tower @ 384x384 """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_base_patch16_clip_384', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_base_patch16_clip_384', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1444,9 +1454,9 @@ def vit_base_patch16_clip_384(pretrained=False, **kwargs): def vit_large_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower """ - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_large_patch14_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1454,9 +1464,9 @@ def vit_large_patch14_clip_224(pretrained=False, **kwargs): def vit_large_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Large model (ViT-L/14) CLIP image tower @ 336x336 """ - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_large_patch14_clip_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_large_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1464,9 +1474,9 @@ def vit_large_patch14_clip_336(pretrained=False, **kwargs): def vit_huge_patch14_clip_224(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower. """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_clip_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1474,9 +1484,9 @@ def vit_huge_patch14_clip_224(pretrained=False, **kwargs): def vit_huge_patch14_clip_336(pretrained=False, **kwargs): """ ViT-Huge model (ViT-H/14) CLIP image tower @ 336x336 """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_huge_patch14_clip_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_huge_patch14_clip_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1486,20 +1496,32 @@ def vit_giant_patch14_clip_224(pretrained=False, **kwargs): Pretrained weights from CLIP image tower. """ model_kwargs = dict( - patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, - pre_norm=True, norm_layer=nn.LayerNorm, **kwargs) - model = _create_vision_transformer('vit_giant_patch14_clip_224', pretrained=pretrained, **model_kwargs) + patch_size=14, embed_dim=1408, mlp_ratio=48/11, depth=40, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_giant_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model +@register_model +def vit_gigantic_patch14_clip_224(pretrained=False, **kwargs): + """ ViT-bigG model (ViT-G/14) from `Scaling Vision Transformers` - https://arxiv.org/abs/2106.04560 + Pretrained weights from CLIP image tower. + """ + model_kwargs = dict( + patch_size=14, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, pre_norm=True, norm_layer=nn.LayerNorm) + model = _create_vision_transformer( + 'vit_gigantic_patch14_clip_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) + return model + # Experimental models below @register_model def vit_base_patch32_plus_256(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32+) """ - model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch32_plus_256', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1507,8 +1529,9 @@ def vit_base_patch32_plus_256(pretrained=False, **kwargs): def vit_base_patch16_plus_240(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16+) """ - model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5) + model = _create_vision_transformer( + 'vit_base_patch16_plus_240', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1517,9 +1540,10 @@ def vit_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ residual post-norm """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, - block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) - model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + class_token=False, block_fn=ResPostBlock, global_pool='avg') + model = _create_vision_transformer( + 'vit_base_patch16_rpn_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1529,8 +1553,9 @@ def vit_small_patch16_36x1_224(pretrained=False, **kwargs): Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5, **kwargs) - model = _create_vision_transformer('vit_small_patch16_36x1_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=36, num_heads=6, init_values=1e-5) + model = _create_vision_transformer( + 'vit_small_patch16_36x1_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1541,8 +1566,9 @@ def vit_small_patch16_18x2_224(pretrained=False, **kwargs): Paper focuses on 24x2 + 48x1 for 'Small' width but those are extremely slow. """ model_kwargs = dict( - patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock, **kwargs) - model = _create_vision_transformer('vit_small_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + patch_size=16, embed_dim=384, depth=18, num_heads=6, init_values=1e-5, block_fn=ParallelBlock) + model = _create_vision_transformer( + 'vit_small_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1551,27 +1577,26 @@ def vit_base_patch16_18x2_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 18 x 2 (36 block parallel) config. Experimental, may remove. Based on `Three things everyone should know about Vision Transformers` - https://arxiv.org/abs/2203.09795 """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock, **kwargs) - model = _create_vision_transformer('vit_base_patch16_18x2_224', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=18, num_heads=12, init_values=1e-5, block_fn=ParallelBlock) + model = _create_vision_transformer( + 'vit_base_patch16_18x2_224', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def eva_large_patch14_196(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 /via MAE MIM pretrain""" - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) - model = _create_vision_transformer('eva_large_patch14_196', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer( + 'eva_large_patch14_196', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @register_model def eva_large_patch14_336(pretrained=False, **kwargs): """ EVA-large model https://arxiv.org/abs/2211.07636 via MAE MIM pretrain""" - model_kwargs = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg', **kwargs) - model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, global_pool='avg') + model = _create_vision_transformer('eva_large_patch14_336', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1579,8 +1604,8 @@ def eva_large_patch14_336(pretrained=False, **kwargs): def flexivit_small(pretrained=False, **kwargs): """ FlexiViT-Small """ - model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True) + model = _create_vision_transformer('flexivit_small', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1588,8 +1613,8 @@ def flexivit_small(pretrained=False, **kwargs): def flexivit_base(pretrained=False, **kwargs): """ FlexiViT-Base """ - model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True) + model = _create_vision_transformer('flexivit_base', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model @@ -1597,6 +1622,6 @@ def flexivit_base(pretrained=False, **kwargs): def flexivit_large(pretrained=False, **kwargs): """ FlexiViT-Large """ - model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, **kwargs) - model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **model_kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True) + model = _create_vision_transformer('flexivit_large', pretrained=pretrained, **dict(model_kwargs, **kwargs)) return model diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index bf0e4f89..8aea5802 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential): class OsaBlock(nn.Module): def __init__( - self, in_chs, mid_chs, out_chs, layer_per_block, residual=False, - depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None): + self, + in_chs, + mid_chs, + out_chs, + layer_per_block, + residual=False, + depthwise=False, + attn='', + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_path=None, + ): super(OsaBlock, self).__init__() self.residual = residual @@ -232,9 +242,20 @@ class OsaBlock(nn.Module): class OsaStage(nn.Module): def __init__( - self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True, - residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, - drop_path_rates=None): + self, + in_chs, + mid_chs, + out_chs, + block_per_stage, + layer_per_block, + downsample=True, + residual=True, + depthwise=False, + attn='ese', + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_path_rates=None, + ): super(OsaStage, self).__init__() self.grad_checkpointing = False @@ -270,16 +291,38 @@ class OsaStage(nn.Module): class VovNet(nn.Module): def __init__( - self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4, - output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.): - """ VovNet (v2) + self, + cfg, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + norm_layer=BatchNormAct2d, + act_layer=nn.ReLU, + drop_rate=0., + drop_path_rate=0., + **kwargs, + ): + """ + Args: + cfg (dict): Model architecture configuration + in_chans (int): Number of input channels (default: 3) + num_classes (int): Number of classifier classes (default: 1000) + global_pool (str): Global pooling type (default: 'avg') + output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) + norm_layer (Union[str, nn.Module]): normalization layer + act_layer (Union[str, nn.Module]): activation layer + drop_rate (float): Dropout rate (default: 0.) + drop_path_rate (float): Stochastic depth drop-path rate (default: 0.) + kwargs (dict): Extra kwargs overlayed onto cfg """ super(VovNet, self).__init__() self.num_classes = num_classes self.drop_rate = drop_rate - assert stem_stride in (4, 2) assert output_stride == 32 # FIXME support dilation + cfg = dict(cfg, **kwargs) + stem_stride = cfg.get("stem_stride", 4) stem_chs = cfg["stem_chs"] stage_conv_chs = cfg["stage_conv_chs"] stage_out_chs = cfg["stage_out_chs"] @@ -307,9 +350,15 @@ class VovNet(nn.Module): for i in range(4): # num_stages downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4 stages += [OsaStage( - in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block, - downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args) - ] + in_ch_list[i], + stage_conv_chs[i], + stage_out_chs[i], + block_per_stage[i], + layer_per_block, + downsample=downsample, + drop_path_rates=stage_dpr[i], + **stage_args, + )] self.num_features = stage_out_chs[i] current_stride *= 2 if downsample else 1 self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')] @@ -324,7 +373,6 @@ class VovNet(nn.Module): elif isinstance(m, nn.Linear): nn.init.zeros_(m.bias) - @torch.jit.ignore def group_matcher(self, coarse=False): return dict( diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index e3348e64..6bb7085f 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -216,7 +216,7 @@ class XceptionAligned(nn.Module): num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] self.act = act_layer(inplace=True) if preact else nn.Identity() self.head = ClassifierHead( - in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_features=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) @torch.jit.ignore def group_matcher(self, coarse=False): diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index a9ff0c78..7727adff 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\ from .jit import set_jit_legacy, set_jit_fuser from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy -from .misc import natural_key, add_bool_arg +from .misc import natural_key, add_bool_arg, ParseKwargs from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed diff --git a/timm/utils/misc.py b/timm/utils/misc.py index 39c0097c..326a50f7 100644 --- a/timm/utils/misc.py +++ b/timm/utils/misc.py @@ -2,6 +2,8 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import argparse +import ast import re @@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''): group.add_argument('--' + name, dest=dest_name, action='store_true', help=help) group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help) parser.set_defaults(**{dest_name: default}) + + +class ParseKwargs(argparse.Action): + def __call__(self, parser, namespace, values, option_string=None): + kw = {} + for value in values: + key, value = value.split('=') + try: + kw[key] = ast.literal_eval(value) + except ValueError: + kw[key] = str(value) # fallback to string (avoid need to escape on command line) + setattr(namespace, self.dest, kw) diff --git a/timm/utils/model.py b/timm/utils/model.py index b95c4539..d74ee5b7 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -7,6 +7,8 @@ import fnmatch import torch from torchvision.ops.misc import FrozenBatchNorm2d +from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\ + freeze_batch_norm_2d, unfreeze_batch_norm_2d from .model_ema import ModelEma @@ -100,70 +102,6 @@ def extract_spp_stats( return hook.stats -def freeze_batch_norm_2d(module): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = freeze_batch_norm_2d(child) - if new_child is not child: - res.add_module(name, new_child) - return res - - -def unfreeze_batch_norm_2d(module): - """ - Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance - of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked - recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - if isinstance(module, FrozenBatchNorm2d): - res = torch.nn.BatchNorm2d(module.num_features) - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = unfreeze_batch_norm_2d(child) - if new_child is not child: - res.add_module(name, new_child) - return res - - def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): """ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is @@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, """ assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' - if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if isinstance(root_module, ( + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.SyncBatchNorm, + BatchNormAct2d, + SyncBatchNormAct, + )): # Raise assertion here because we can't convert it in place raise AssertionError( "You have provided a batch norm layer as the `root module`. Please use " @@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't # convert it in place, but will return the converted result. In this case `res` holds the converted # result and we may try to re-assign the named module - if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if isinstance(m, ( + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.SyncBatchNorm, + BatchNormAct2d, + SyncBatchNormAct, + )): _add_submodule(root_module, n, res) # Unfreeze batch norm else: res = unfreeze_batch_norm_2d(m) # Ditto. See note above in mode == 'freeze' branch - if isinstance(m, FrozenBatchNorm2d): + if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)): _add_submodule(root_module, n, res) diff --git a/timm/version.py b/timm/version.py index b110e6cc..e2ac9a76 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.8.4dev0' +__version__ = '0.8.8dev0' diff --git a/train.py b/train.py index e51d7c90..9f450ab8 100755 --- a/train.py +++ b/train.py @@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR', parser.add_argument('--dataset', metavar='NAME', default='', help='dataset type + name ("/") (default: ImageFolder or ImageTar if empty)') group.add_argument('--train-split', metavar='NAME', default='train', - help='dataset train split (default: train)') + help='dataset train split (default: train)') group.add_argument('--val-split', metavar='NAME', default='validation', - help='dataset validation split (default: validation)') + help='dataset validation split (default: validation)') group.add_argument('--dataset-download', action='store_true', default=False, - help='Allow download of dataset for torch/ and tfds/ datasets that support it.') + help='Allow download of dataset for torch/ and tfds/ datasets that support it.') group.add_argument('--class-map', default='', type=str, metavar='FILENAME', - help='path to class to idx mapping file (default: "")') + help='path to class to idx mapping file (default: "")') # Model parameters group = parser.add_argument_group('Model parameters') group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', - help='Name of model to train (default: "resnet50")') + help='Name of model to train (default: "resnet50")') group.add_argument('--pretrained', action='store_true', default=False, - help='Start with pretrained version of specified network (if avail)') + help='Start with pretrained version of specified network (if avail)') group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', - help='Initialize model from this checkpoint (default: none)') + help='Initialize model from this checkpoint (default: none)') group.add_argument('--resume', default='', type=str, metavar='PATH', - help='Resume full model and optimizer state from checkpoint (default: none)') + help='Resume full model and optimizer state from checkpoint (default: none)') group.add_argument('--no-resume-opt', action='store_true', default=False, - help='prevent resume of optimizer state when resuming model') + help='prevent resume of optimizer state when resuming model') group.add_argument('--num-classes', type=int, default=None, metavar='N', - help='number of label classes (Model default if None)') + help='number of label classes (Model default if None)') group.add_argument('--gp', default=None, type=str, metavar='POOL', - help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') + help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.') group.add_argument('--img-size', type=int, default=None, metavar='N', - help='Image size (default: None => model default)') + help='Image size (default: None => model default)') group.add_argument('--in-chans', type=int, default=None, metavar='N', - help='Image input channels (default: None => 3)') + help='Image input channels (default: None => 3)') group.add_argument('--input-size', default=None, nargs=3, type=int, - metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') + metavar='N N N', + help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') group.add_argument('--crop-pct', default=None, type=float, - metavar='N', help='Input image center crop percent (for validation only)') + metavar='N', help='Input image center crop percent (for validation only)') group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', - help='Override mean pixel value of dataset') + help='Override mean pixel value of dataset') group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', - help='Override std deviation of dataset') + help='Override std deviation of dataset') group.add_argument('--interpolation', default='', type=str, metavar='NAME', - help='Image resize interpolation type (overrides model)') + help='Image resize interpolation type (overrides model)') group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', - help='Input batch size for training (default: 128)') + help='Input batch size for training (default: 128)') group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', - help='Validation batch size override (default: None)') + help='Validation batch size override (default: None)') group.add_argument('--channels-last', action='store_true', default=False, - help='Use channels_last memory layout') + help='Use channels_last memory layout') group.add_argument('--fuser', default='', type=str, - help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") + help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") group.add_argument('--grad-checkpointing', action='store_true', default=False, - help='Enable gradient checkpointing through model blocks/stages') + help='Enable gradient checkpointing through model blocks/stages') group.add_argument('--fast-norm', default=False, action='store_true', - help='enable experimental fast-norm') + help='enable experimental fast-norm') +group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs) scripting_group = group.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true', @@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true # Optimizer parameters group = parser.add_argument_group('Optimizer parameters') group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', - help='Optimizer (default: "sgd")') + help='Optimizer (default: "sgd")') group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', - help='Optimizer Epsilon (default: None, use opt default)') + help='Optimizer Epsilon (default: None, use opt default)') group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', - help='Optimizer Betas (default: None, use opt default)') + help='Optimizer Betas (default: None, use opt default)') group.add_argument('--momentum', type=float, default=0.9, metavar='M', - help='Optimizer momentum (default: 0.9)') + help='Optimizer momentum (default: 0.9)') group.add_argument('--weight-decay', type=float, default=2e-5, - help='weight decay (default: 2e-5)') + help='weight decay (default: 2e-5)') group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', - help='Clip gradient norm (default: None, no clipping)') + help='Clip gradient norm (default: None, no clipping)') group.add_argument('--clip-mode', type=str, default='norm', - help='Gradient clipping mode. One of ("norm", "value", "agc")') + help='Gradient clipping mode. One of ("norm", "value", "agc")') group.add_argument('--layer-decay', type=float, default=None, - help='layer-wise learning rate decay (default: None)') + help='layer-wise learning rate decay (default: None)') +group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs) # Learning rate schedule parameters group = parser.add_argument_group('Learning rate schedule parameters') group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER', - help='LR scheduler (default: "step"') + help='LR scheduler (default: "step"') group.add_argument('--sched-on-updates', action='store_true', default=False, - help='Apply LR scheduler step on update instead of epoch end.') + help='Apply LR scheduler step on update instead of epoch end.') group.add_argument('--lr', type=float, default=None, metavar='LR', - help='learning rate, overrides lr-base if set (default: None)') + help='learning rate, overrides lr-base if set (default: None)') group.add_argument('--lr-base', type=float, default=0.1, metavar='LR', - help='base learning rate: lr = lr_base * global_batch_size / base_size') + help='base learning rate: lr = lr_base * global_batch_size / base_size') group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV', - help='base learning rate batch size (divisor, default: 256).') + help='base learning rate batch size (divisor, default: 256).') group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE', - help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') + help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)') group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', - help='learning rate noise on/off epoch percentages') + help='learning rate noise on/off epoch percentages') group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', - help='learning rate noise limit percent (default: 0.67)') + help='learning rate noise limit percent (default: 0.67)') group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', - help='learning rate noise std-dev (default: 1.0)') + help='learning rate noise std-dev (default: 1.0)') group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', - help='learning rate cycle len multiplier (default: 1.0)') + help='learning rate cycle len multiplier (default: 1.0)') group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', - help='amount to decay each learning rate cycle (default: 0.5)') + help='amount to decay each learning rate cycle (default: 0.5)') group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', - help='learning rate cycle limit, cycles enabled if > 1') + help='learning rate cycle limit, cycles enabled if > 1') group.add_argument('--lr-k-decay', type=float, default=1.0, - help='learning rate k-decay for cosine/poly (default: 1.0)') + help='learning rate k-decay for cosine/poly (default: 1.0)') group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR', - help='warmup learning rate (default: 1e-5)') + help='warmup learning rate (default: 1e-5)') group.add_argument('--min-lr', type=float, default=0, metavar='LR', - help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') + help='lower lr bound for cyclic schedulers that hit 0 (default: 0)') group.add_argument('--epochs', type=int, default=300, metavar='N', - help='number of epochs to train (default: 300)') + help='number of epochs to train (default: 300)') group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', - help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') + help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') group.add_argument('--start-epoch', default=None, type=int, metavar='N', - help='manual epoch number (useful on restarts)') + help='manual epoch number (useful on restarts)') group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES", - help='list of decay epoch indices for multistep lr. must be increasing') + help='list of decay epoch indices for multistep lr. must be increasing') group.add_argument('--decay-epochs', type=float, default=90, metavar='N', - help='epoch interval to decay LR') + help='epoch interval to decay LR') group.add_argument('--warmup-epochs', type=int, default=5, metavar='N', - help='epochs to warmup LR, if scheduler supports') + help='epochs to warmup LR, if scheduler supports') group.add_argument('--warmup-prefix', action='store_true', default=False, - help='Exclude warmup period from decay schedule.'), + help='Exclude warmup period from decay schedule.'), group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N', - help='epochs to cooldown LR at min_lr, after cyclic schedule ends') + help='epochs to cooldown LR at min_lr, after cyclic schedule ends') group.add_argument('--patience-epochs', type=int, default=10, metavar='N', - help='patience epochs for Plateau LR scheduler (default: 10)') + help='patience epochs for Plateau LR scheduler (default: 10)') group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', - help='LR decay rate (default: 0.1)') + help='LR decay rate (default: 0.1)') # Augmentation & regularization parameters group = parser.add_argument_group('Augmentation and regularization parameters') group.add_argument('--no-aug', action='store_true', default=False, - help='Disable all training augmentation, override other train aug args') + help='Disable all training augmentation, override other train aug args') group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', - help='Random resize scale (default: 0.08 1.0)') -group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', - help='Random resize aspect ratio (default: 0.75 1.33)') + help='Random resize scale (default: 0.08 1.0)') +group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO', + help='Random resize aspect ratio (default: 0.75 1.33)') group.add_argument('--hflip', type=float, default=0.5, - help='Horizontal flip training aug probability') + help='Horizontal flip training aug probability') group.add_argument('--vflip', type=float, default=0., - help='Vertical flip training aug probability') + help='Vertical flip training aug probability') group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', - help='Color jitter factor (default: 0.4)') + help='Color jitter factor (default: 0.4)') group.add_argument('--aa', type=str, default=None, metavar='NAME', - help='Use AutoAugment policy. "v0" or "original". (default: None)'), + help='Use AutoAugment policy. "v0" or "original". (default: None)'), group.add_argument('--aug-repeats', type=float, default=0, - help='Number of augmentation repetitions (distributed training only) (default: 0)') + help='Number of augmentation repetitions (distributed training only) (default: 0)') group.add_argument('--aug-splits', type=int, default=0, - help='Number of augmentation splits (default: 0, valid: 0 or >=2)') + help='Number of augmentation splits (default: 0, valid: 0 or >=2)') group.add_argument('--jsd-loss', action='store_true', default=False, - help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') + help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') group.add_argument('--bce-loss', action='store_true', default=False, - help='Enable BCE loss w/ Mixup/CutMix use.') + help='Enable BCE loss w/ Mixup/CutMix use.') group.add_argument('--bce-target-thresh', type=float, default=None, - help='Threshold for binarizing softened BCE targets (default: None, disabled)') + help='Threshold for binarizing softened BCE targets (default: None, disabled)') group.add_argument('--reprob', type=float, default=0., metavar='PCT', - help='Random erase prob (default: 0.)') + help='Random erase prob (default: 0.)') group.add_argument('--remode', type=str, default='pixel', - help='Random erase mode (default: "pixel")') + help='Random erase mode (default: "pixel")') group.add_argument('--recount', type=int, default=1, - help='Random erase count (default: 1)') + help='Random erase count (default: 1)') group.add_argument('--resplit', action='store_true', default=False, - help='Do not random erase first (clean) augmentation split') + help='Do not random erase first (clean) augmentation split') group.add_argument('--mixup', type=float, default=0.0, - help='mixup alpha, mixup enabled if > 0. (default: 0.)') + help='mixup alpha, mixup enabled if > 0. (default: 0.)') group.add_argument('--cutmix', type=float, default=0.0, - help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') + help='cutmix alpha, cutmix enabled if > 0. (default: 0.)') group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, - help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') + help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') group.add_argument('--mixup-prob', type=float, default=1.0, - help='Probability of performing mixup or cutmix when either/both is enabled') + help='Probability of performing mixup or cutmix when either/both is enabled') group.add_argument('--mixup-switch-prob', type=float, default=0.5, - help='Probability of switching to cutmix when both mixup and cutmix enabled') + help='Probability of switching to cutmix when both mixup and cutmix enabled') group.add_argument('--mixup-mode', type=str, default='batch', - help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') + help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', - help='Turn off mixup after this epoch, disabled if 0 (default: 0)') + help='Turn off mixup after this epoch, disabled if 0 (default: 0)') group.add_argument('--smoothing', type=float, default=0.1, - help='Label smoothing (default: 0.1)') + help='Label smoothing (default: 0.1)') group.add_argument('--train-interpolation', type=str, default='random', - help='Training interpolation (random, bilinear, bicubic default: "random")') + help='Training interpolation (random, bilinear, bicubic default: "random")') group.add_argument('--drop', type=float, default=0.0, metavar='PCT', - help='Dropout rate (default: 0.)') + help='Dropout rate (default: 0.)') group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', - help='Drop connect rate, DEPRECATED, use drop-path (default: None)') + help='Drop connect rate, DEPRECATED, use drop-path (default: None)') group.add_argument('--drop-path', type=float, default=None, metavar='PCT', - help='Drop path rate (default: None)') + help='Drop path rate (default: None)') group.add_argument('--drop-block', type=float, default=None, metavar='PCT', - help='Drop block rate (default: None)') + help='Drop block rate (default: None)') # Batch norm parameters (only works with gen_efficientnet based models currently) group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') group.add_argument('--bn-momentum', type=float, default=None, - help='BatchNorm momentum override (if not None)') + help='BatchNorm momentum override (if not None)') group.add_argument('--bn-eps', type=float, default=None, - help='BatchNorm epsilon override (if not None)') + help='BatchNorm epsilon override (if not None)') group.add_argument('--sync-bn', action='store_true', - help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') group.add_argument('--dist-bn', type=str, default='reduce', - help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') group.add_argument('--split-bn', action='store_true', - help='Enable separate BN layers per augmentation split.') + help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average group = parser.add_argument_group('Model exponential moving average parameters') group.add_argument('--model-ema', action='store_true', default=False, - help='Enable tracking moving average of model weights') + help='Enable tracking moving average of model weights') group.add_argument('--model-ema-force-cpu', action='store_true', default=False, - help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') + help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') group.add_argument('--model-ema-decay', type=float, default=0.9998, - help='decay factor for model weights moving average (default: 0.9998)') + help='decay factor for model weights moving average (default: 0.9998)') # Misc group = parser.add_argument_group('Miscellaneous parameters') group.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') + help='random seed (default: 42)') group.add_argument('--worker-seeding', type=str, default='all', - help='worker seed mode (default: all)') + help='worker seed mode (default: all)') group.add_argument('--log-interval', type=int, default=50, metavar='N', - help='how many batches to wait before logging training status') + help='how many batches to wait before logging training status') group.add_argument('--recovery-interval', type=int, default=0, metavar='N', - help='how many batches to wait before writing recovery checkpoint') + help='how many batches to wait before writing recovery checkpoint') group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', - help='number of checkpoints to keep (default: 10)') + help='number of checkpoints to keep (default: 10)') group.add_argument('-j', '--workers', type=int, default=4, metavar='N', - help='how many training processes to use (default: 4)') + help='how many training processes to use (default: 4)') group.add_argument('--save-images', action='store_true', default=False, - help='save images of input bathes every log interval for debugging') + help='save images of input bathes every log interval for debugging') group.add_argument('--amp', action='store_true', default=False, - help='use NVIDIA Apex AMP or Native AMP for mixed precision training') + help='use NVIDIA Apex AMP or Native AMP for mixed precision training') group.add_argument('--amp-dtype', default='float16', type=str, - help='lower precision AMP dtype (default: float16)') + help='lower precision AMP dtype (default: float16)') group.add_argument('--amp-impl', default='native', type=str, - help='AMP impl to use, "native" or "apex" (default: native)') + help='AMP impl to use, "native" or "apex" (default: native)') group.add_argument('--no-ddp-bb', action='store_true', default=False, - help='Force broadcast buffers for native DDP to off.') + help='Force broadcast buffers for native DDP to off.') group.add_argument('--pin-mem', action='store_true', default=False, - help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') + help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') group.add_argument('--no-prefetcher', action='store_true', default=False, - help='disable fast prefetcher') + help='disable fast prefetcher') group.add_argument('--output', default='', type=str, metavar='PATH', - help='path to output folder (default: none, current dir)') + help='path to output folder (default: none, current dir)') group.add_argument('--experiment', default='', type=str, metavar='NAME', - help='name of train experiment, name of sub-folder for output') + help='name of train experiment, name of sub-folder for output') group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', - help='Best metric (default: "top1"') + help='Best metric (default: "top1"') group.add_argument('--tta', type=int, default=0, metavar='N', - help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') + help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') group.add_argument("--local_rank", default=0, type=int) group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, - help='use the multi-epochs-loader to save time at the beginning of every epoch') + help='use the multi-epochs-loader to save time at the beginning of every epoch') group.add_argument('--log-wandb', action='store_true', default=False, - help='log training and validation metrics to wandb') + help='log training and validation metrics to wandb') def _parse_args(): @@ -371,8 +374,6 @@ def main(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True - if args.data and not args.data_dir: - args.data_dir = args.data args.prefetcher = not args.no_prefetcher device = utils.init_distributed_device(args) if args.distributed: @@ -383,14 +384,6 @@ def main(): _logger.info(f'Training with a single process on 1 device ({args.device}).') assert args.rank >= 0 - if utils.is_primary(args) and args.log_wandb: - if has_wandb: - wandb.init(project=args.experiment, config=args) - else: - _logger.warning( - "You've requested to log metrics to wandb but package not found. " - "Metrics not being logged to wandb, try `pip install wandb`") - # resolve AMP arguments based on PyTorch / Apex availability use_amp = None amp_dtype = torch.float16 @@ -432,6 +425,7 @@ def main(): bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' @@ -504,7 +498,11 @@ def main(): f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) ' f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.') - optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args)) + optimizer = create_optimizer_v2( + model, + **optimizer_kwargs(cfg=args), + **args.opt_kwargs, + ) # setup automatic mixed-precision (AMP) loss scaling and op casting amp_autocast = suppress # do nothing @@ -559,6 +557,8 @@ def main(): # NOTE: EMA model does not need to be wrapped by DDP # create the train and eval datasets + if args.data and not args.data_dir: + args.data_dir = args.data dataset_train = create_dataset( args.dataset, root=args.data_dir, @@ -712,6 +712,14 @@ def main(): with open(os.path.join(output_dir, 'args.yaml'), 'w') as f: f.write(args_text) + if utils.is_primary(args) and args.log_wandb: + if has_wandb: + wandb.init(project=args.experiment, config=args) + else: + _logger.warning( + "You've requested to log metrics to wandb but package not found. " + "Metrics not being logged to wandb, try `pip install wandb`") + # setup learning rate schedule and starting epoch updates_per_epoch = len(loader_train) lr_scheduler, num_epochs = create_scheduler_v2( diff --git a/validate.py b/validate.py index 4669fbac..b606103d 100755 --- a/validate.py +++ b/validate.py @@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa from timm.layers import apply_test_time_pool, set_fast_norm from timm.models import create_model, load_checkpoint, is_model, list_models from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \ - decay_batch_step, check_batch_size_retry + decay_batch_step, check_batch_size_retry, ParseKwargs try: from apex import amp @@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int, metavar='N', help='mini-batch size (default: 256)') parser.add_argument('--img-size', default=None, type=int, metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--in-chans', type=int, default=None, metavar='N', + help='Image input channels (default: None => 3)') parser.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') parser.add_argument('--use-train-size', action='store_true', default=False, @@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str, help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')") parser.add_argument('--fast-norm', default=False, action='store_true', help='enable experimental fast-norm') +parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs) + scripting_group = parser.add_mutually_exclusive_group() scripting_group.add_argument('--torchscript', default=False, action='store_true', @@ -181,13 +185,20 @@ def validate(args): set_fast_norm() # create model + in_chans = 3 + if args.in_chans is not None: + in_chans = args.in_chans + elif args.input_size is not None: + in_chans = args.input_size[0] + model = create_model( args.model, pretrained=args.pretrained, num_classes=args.num_classes, - in_chans=3, + in_chans=in_chans, global_pool=args.gp, scriptable=args.torchscript, + **args.model_kwargs, ) if args.num_classes is None: assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.' @@ -232,8 +243,9 @@ def validate(args): criterion = nn.CrossEntropyLoss().to(device) + root_dir = args.data or args.data_dir dataset = create_dataset( - root=args.data, + root=root_dir, name=args.dataset, split=args.split, download=args.dataset_download, @@ -389,7 +401,7 @@ def main(): if args.model == 'all': # validate all models in a list of names with pretrained checkpoints args.pretrained = True - model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino']) + model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae']) model_cfgs = [(n, '') for n in model_names] elif not is_model(args.model): # model name doesn't exist, try as wildcard filter