Davit update formatting and fix grad checkpointing (#7)

fixed head to gap->norm->fc as per convnext, along with option for norm->gap->fc
failed tests due to clip convnext models, davit tests passed
pull/1583/head
Fredo Guan 1 year ago committed by GitHub
parent 10b3f696b4
commit 81ca323751
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,5 +16,6 @@ jobs:
package_name: timm
repo_owner: rwightman
path_to_docs: pytorch-image-models/hfdocs/source
version_tag_suffix: ""
secrets:
token: ${{ secrets.HUGGINGFACE_PUSH }}

@ -17,3 +17,4 @@ jobs:
package_name: timm
repo_owner: rwightman
path_to_docs: pytorch-image-models/hfdocs/source
version_tag_suffix: ""

@ -40,9 +40,10 @@ jobs:
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: |
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo sed -i 's/azure\.//' /etc/apt/sources.list
sudo apt update
sudo apt install -y google-perftools
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
- name: Install requirements
run: |
pip install -r requirements.txt

10
.gitignore vendored

@ -106,6 +106,16 @@ output/
*.tar
*.pth
*.pt
*.torch
*.gz
Untitled.ipynb
Testing notebook.ipynb
# Root dir exclusions
/*.csv
/*.yaml
/*.json
/*.jpg
/*.png
/*.zip
/*.tar.*

@ -21,12 +21,36 @@ And a big thanks to all GitHub sponsors who helped with some of my costs before
## What's New
### 🤗 Survey: Feedback Appreciated 🤗
For a few months now, `timm` has been part of the Hugging Face ecosystem. Yearly, we survey users of our tools to see what we could do better, what we need to continue doing, or what we need to stop doing.
If you have a couple of minutes and want to participate in shaping the future of the ecosystem, please share your thoughts:
[**hf.co/oss-survey**](https://hf.co/oss-survey) 🙏
* ❗Updates after Oct 10, 2022 are available in 0.8.x pre-releases (`pip install --pre timm`) or cloning main❗
* Stable releases are 0.6.x and available by normal pip install or clone from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch.
### Jan 11, 2023
* Update ConvNeXt ImageNet-12k pretrain series w/ two new fine-tuned weights (and pre FT `.in12k` tags)
* `convnext_nano.in12k_ft_in1k` - 82.3 @ 224, 82.9 @ 288 (previously released)
* `convnext_tiny.in12k_ft_in1k` - 84.2 @ 224, 84.5 @ 288
* `convnext_small.in12k_ft_in1k` - 85.2 @ 224, 85.3 @ 288
### Jan 6, 2023
* Finally got around to adding `--model-kwargs` and `--opt-kwargs` to scripts to pass through rare args directly to model classes from cmd line
* `train.py /imagenet --model resnet50 --amp --model-kwargs output_stride=16 act_layer=silu`
* `train.py /imagenet --model vit_base_patch16_clip_224 --img-size 240 --amp --model-kwargs img_size=240 patch_size=12`
* Cleanup some popular models to better support arg passthrough / merge with model configs, more to go.
### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`
* Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808)
* Reference impl: https://github.com/facebookresearch/ConvNeXt-V2 (NOTE: weights currently CC-BY-NC)
### Dec 23, 2022 🎄☃
* Add FlexiViT models and weights from https://github.com/google-research/big_vision (check out paper at https://arxiv.org/abs/2212.08013)
* NOTE currently resizing is static on model creation, on-the-fly dynamic / train patch size sampling is a WIP
* Many more models updated to multi-weight and downloadable via HF hub now (convnext, efficientnet, mobilenet, vision_transformer*, beit)
* More model pretrained tag and adjustments, some model names changed (working on deprecation translations, consider main branch DEV branch right now, use 0.6.x for stable use)
* More ImageNet-12k (subset of 22k) pretrain models popping up:
* `efficientnet_b5.in12k_ft_in1k` - 85.9 @ 448x448
* `vit_medium_patch16_gap_384.in12k_ft_in1k` - 85.5 @ 384x384
* `vit_medium_patch16_gap_256.in12k_ft_in1k` - 84.5 @ 256x256
* `convnext_nano.in12k_ft_in1k` - 82.9 @ 288x288
### Dec 8, 2022
* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some)
@ -325,46 +349,6 @@ More models, more fixes
* TinyNet models added by [rsomani95](https://github.com/rsomani95)
* LCNet added via MobileNetV3 architecture
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
* `halonet50ts` - 81.7 @ 256
* `halo2botnet50ts_256` - 82.0 @ 256
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Oct 19, 2021
* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* BCE loss and Repeated Augmentation support for RSB paper
* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (https://arxiv.org/abs/2103.12731)
* Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
* LambdaNetworks (https://arxiv.org/abs/2102.08602)
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
## Introduction
Py**T**orch **Im**age **M**odels (`timm`) is a collection of image models, layers, utilities, optimizers, schedulers, data-loaders / augmentations, and reference training / validation scripts that aim to pull together a wide variety of SOTA models with ability to reproduce ImageNet training results.
@ -385,6 +369,7 @@ A full version of the list below with source links can be found in the [document
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* CoAtNet (Convolution and Attention) - https://arxiv.org/abs/2106.04803
* ConvNeXt - https://arxiv.org/abs/2201.03545
* ConvNeXt-V2 - http://arxiv.org/abs/2301.00808
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT - https://arxiv.org/abs/2012.12877
@ -407,6 +392,7 @@ A full version of the list below with source links can be found in the [document
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* TinyNet - https://arxiv.org/abs/2010.14819
* EVA - https://arxiv.org/abs/2211.07636
* FlexiViT - https://arxiv.org/abs/2212.08013
* GCViT (Global Context Vision Transformer) - https://arxiv.org/abs/2206.09959
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050

@ -22,7 +22,7 @@ from timm.data import resolve_data_config
from timm.layers import set_fast_norm
from timm.models import create_model, is_model, list_models
from timm.optim import create_optimizer_v2
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry
from timm.utils import setup_default_logging, set_jit_fuser, decay_batch_step, check_batch_size_retry, ParseKwargs
has_apex = False
try:
@ -108,12 +108,15 @@ parser.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
parser.add_argument('--amp', action='store_true', default=False,
help='use PyTorch Native AMP for mixed precision training. Overrides --precision arg.')
parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16). Overrides --precision arg if args.amp True.')
parser.add_argument('--precision', default='float32', type=str,
help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
# codegen (model compilation) options
scripting_group = parser.add_mutually_exclusive_group()
@ -124,7 +127,6 @@ scripting_group.add_argument('--torchcompile', nargs='?', type=str, default=None
scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
help="Enable AOT Autograd optimization.")
# train optimizer parameters
parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd"')
@ -168,19 +170,21 @@ def count_params(model: nn.Module):
def resolve_precision(precision: str):
assert precision in ('amp', 'float16', 'bfloat16', 'float32')
use_amp = False
assert precision in ('amp', 'amp_bfloat16', 'float16', 'bfloat16', 'float32')
amp_dtype = None # amp disabled
model_dtype = torch.float32
data_dtype = torch.float32
if precision == 'amp':
use_amp = True
amp_dtype = torch.float16
elif precision == 'amp_bfloat16':
amp_dtype = torch.bfloat16
elif precision == 'float16':
model_dtype = torch.float16
data_dtype = torch.float16
elif precision == 'bfloat16':
model_dtype = torch.bfloat16
data_dtype = torch.bfloat16
return use_amp, model_dtype, data_dtype
return amp_dtype, model_dtype, data_dtype
def profile_deepspeed(model, input_size=(3, 224, 224), batch_size=1, detailed=False):
@ -228,9 +232,12 @@ class BenchmarkRunner:
self.model_name = model_name
self.detail = detail
self.device = device
self.use_amp, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.amp_dtype, self.model_dtype, self.data_dtype = resolve_precision(precision)
self.channels_last = kwargs.pop('channels_last', False)
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=torch.float16) if self.use_amp else suppress
if self.amp_dtype is not None:
self.amp_autocast = partial(torch.cuda.amp.autocast, dtype=self.amp_dtype)
else:
self.amp_autocast = suppress
if fuser:
set_jit_fuser(fuser)
@ -243,6 +250,7 @@ class BenchmarkRunner:
drop_rate=kwargs.pop('drop', 0.),
drop_path_rate=kwargs.pop('drop_path', None),
drop_block_rate=kwargs.pop('drop_block', None),
**kwargs.pop('model_kwargs', {}),
)
self.model.to(
device=self.device,
@ -560,7 +568,7 @@ def _try_run(
def benchmark(args):
if args.amp:
_logger.warning("Overriding precision to 'amp' since --amp flag set.")
args.precision = 'amp'
args.precision = 'amp' if args.amp_dtype == 'float16' else '_'.join(['amp', args.amp_dtype])
_logger.info(f'Benchmarking in {args.precision} precision. '
f'{"NHWC" if args.channels_last else "NCHW"} layout. '
f'torchscript {"enabled" if args.torchscript else "disabled"}')

@ -1,5 +1,45 @@
# Archived Changes
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
* `halonet50ts` - 81.7 @ 256
* `halo2botnet50ts_256` - 82.0 @ 256
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Oct 19, 2021
* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* BCE loss and Repeated Augmentation support for RSB paper
* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (https://arxiv.org/abs/2103.12731)
* Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
* LambdaNetworks (https://arxiv.org/abs/2102.08602)
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
### July 12, 2021
* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)

@ -1,4 +1,183 @@
# Recent Changes
### Jan 5, 2023
* ConvNeXt-V2 models and weights added to existing `convnext.py`
* Paper: [ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders](http://arxiv.org/abs/2301.00808)
* Reference impl: https://github.com/facebookresearch/ConvNeXt-V2 (NOTE: weights currently CC-BY-NC)
### Dec 23, 2022 🎄☃
* Add FlexiViT models and weights from https://github.com/google-research/big_vision (check out paper at https://arxiv.org/abs/2212.08013)
* NOTE currently resizing is static on model creation, on-the-fly dynamic / train patch size sampling is a WIP
* Many more models updated to multi-weight and downloadable via HF hub now (convnext, efficientnet, mobilenet, vision_transformer*, beit)
* More model pretrained tag and adjustments, some model names changed (working on deprecation translations, consider main branch DEV branch right now, use 0.6.x for stable use)
* More ImageNet-12k (subset of 22k) pretrain models popping up:
* `efficientnet_b5.in12k_ft_in1k` - 85.9 @ 448x448
* `vit_medium_patch16_gap_384.in12k_ft_in1k` - 85.5 @ 384x384
* `vit_medium_patch16_gap_256.in12k_ft_in1k` - 84.5 @ 256x256
* `convnext_nano.in12k_ft_in1k` - 82.9 @ 288x288
### Dec 8, 2022
* Add 'EVA l' to `vision_transformer.py`, MAE style ViT-L/14 MIM pretrain w/ EVA-CLIP targets, FT on ImageNet-1k (w/ ImageNet-22k intermediate for some)
* original source: https://github.com/baaivision/EVA
| model | top1 | param_count | gmac | macts | hub |
|:------------------------------------------|-----:|------------:|------:|------:|:----------------------------------------|
| eva_large_patch14_336.in22k_ft_in22k_in1k | 89.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_336.in22k_ft_in1k | 88.7 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_196.in22k_ft_in22k_in1k | 88.6 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
| eva_large_patch14_196.in22k_ft_in1k | 87.9 | 304.1 | 61.6 | 63.5 | [link](https://huggingface.co/BAAI/EVA) |
### Dec 6, 2022
* Add 'EVA g', BEiT style ViT-g/14 model weights w/ both MIM pretrain and CLIP pretrain to `beit.py`.
* original source: https://github.com/baaivision/EVA
* paper: https://arxiv.org/abs/2211.07636
| model | top1 | param_count | gmac | macts | hub |
|:-----------------------------------------|-------:|--------------:|-------:|--------:|:----------------------------------------|
| eva_giant_patch14_560.m30m_ft_in22k_in1k | 89.8 | 1014.4 | 1906.8 | 2577.2 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_336.m30m_ft_in22k_in1k | 89.6 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_336.clip_ft_in1k | 89.4 | 1013 | 620.6 | 550.7 | [link](https://huggingface.co/BAAI/EVA) |
| eva_giant_patch14_224.clip_ft_in1k | 89.1 | 1012.6 | 267.2 | 192.6 | [link](https://huggingface.co/BAAI/EVA) |
### Dec 5, 2022
* Pre-release (`0.8.0dev0`) of multi-weight support (`model_arch.pretrained_tag`). Install with `pip install --pre timm`
* vision_transformer, maxvit, convnext are the first three model impl w/ support
* model names are changing with this (previous _21k, etc. fn will merge), still sorting out deprecation handling
* bugs are likely, but I need feedback so please try it out
* if stability is needed, please use 0.6.x pypi releases or clone from [0.6.x branch](https://github.com/rwightman/pytorch-image-models/tree/0.6.x)
* Support for PyTorch 2.0 compile is added in train/validate/inference/benchmark, use `--torchcompile` argument
* Inference script allows more control over output, select k for top-class index + prob json, csv or parquet output
* Add a full set of fine-tuned CLIP image tower weights from both LAION-2B and original OpenAI CLIP models
| model | top1 | param_count | gmac | macts | hub |
|:-------------------------------------------------|-------:|--------------:|-------:|--------:|:-------------------------------------------------------------------------------------|
| vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k | 88.6 | 632.5 | 391 | 407.5 | [link](https://huggingface.co/timm/vit_huge_patch14_clip_336.laion2b_ft_in12k_in1k) |
| vit_large_patch14_clip_336.openai_ft_in12k_in1k | 88.3 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/timm/vit_large_patch14_clip_336.openai_ft_in12k_in1k) |
| vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k | 88.2 | 632 | 167.4 | 139.4 | [link](https://huggingface.co/timm/vit_huge_patch14_clip_224.laion2b_ft_in12k_in1k) |
| vit_large_patch14_clip_336.laion2b_ft_in12k_in1k | 88.2 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/timm/vit_large_patch14_clip_336.laion2b_ft_in12k_in1k) |
| vit_large_patch14_clip_224.openai_ft_in12k_in1k | 88.2 | 304.2 | 81.1 | 88.8 | [link](https://huggingface.co/timm/vit_large_patch14_clip_224.openai_ft_in12k_in1k) |
| vit_large_patch14_clip_224.laion2b_ft_in12k_in1k | 87.9 | 304.2 | 81.1 | 88.8 | [link](https://huggingface.co/timm/vit_large_patch14_clip_224.laion2b_ft_in12k_in1k) |
| vit_large_patch14_clip_224.openai_ft_in1k | 87.9 | 304.2 | 81.1 | 88.8 | [link](https://huggingface.co/timm/vit_large_patch14_clip_224.openai_ft_in1k) |
| vit_large_patch14_clip_336.laion2b_ft_in1k | 87.9 | 304.5 | 191.1 | 270.2 | [link](https://huggingface.co/timm/vit_large_patch14_clip_336.laion2b_ft_in1k) |
| vit_huge_patch14_clip_224.laion2b_ft_in1k | 87.6 | 632 | 167.4 | 139.4 | [link](https://huggingface.co/timm/vit_huge_patch14_clip_224.laion2b_ft_in1k) |
| vit_large_patch14_clip_224.laion2b_ft_in1k | 87.3 | 304.2 | 81.1 | 88.8 | [link](https://huggingface.co/timm/vit_large_patch14_clip_224.laion2b_ft_in1k) |
| vit_base_patch16_clip_384.laion2b_ft_in12k_in1k | 87.2 | 86.9 | 55.5 | 101.6 | [link](https://huggingface.co/timm/vit_base_patch16_clip_384.laion2b_ft_in12k_in1k) |
| vit_base_patch16_clip_384.openai_ft_in12k_in1k | 87 | 86.9 | 55.5 | 101.6 | [link](https://huggingface.co/timm/vit_base_patch16_clip_384.openai_ft_in12k_in1k) |
| vit_base_patch16_clip_384.laion2b_ft_in1k | 86.6 | 86.9 | 55.5 | 101.6 | [link](https://huggingface.co/timm/vit_base_patch16_clip_384.laion2b_ft_in1k) |
| vit_base_patch16_clip_384.openai_ft_in1k | 86.2 | 86.9 | 55.5 | 101.6 | [link](https://huggingface.co/timm/vit_base_patch16_clip_384.openai_ft_in1k) |
| vit_base_patch16_clip_224.laion2b_ft_in12k_in1k | 86.2 | 86.6 | 17.6 | 23.9 | [link](https://huggingface.co/timm/vit_base_patch16_clip_224.laion2b_ft_in12k_in1k) |
| vit_base_patch16_clip_224.openai_ft_in12k_in1k | 85.9 | 86.6 | 17.6 | 23.9 | [link](https://huggingface.co/timm/vit_base_patch16_clip_224.openai_ft_in12k_in1k) |
| vit_base_patch32_clip_448.laion2b_ft_in12k_in1k | 85.8 | 88.3 | 17.9 | 23.9 | [link](https://huggingface.co/timm/vit_base_patch32_clip_448.laion2b_ft_in12k_in1k) |
| vit_base_patch16_clip_224.laion2b_ft_in1k | 85.5 | 86.6 | 17.6 | 23.9 | [link](https://huggingface.co/timm/vit_base_patch16_clip_224.laion2b_ft_in1k) |
| vit_base_patch32_clip_384.laion2b_ft_in12k_in1k | 85.4 | 88.3 | 13.1 | 16.5 | [link](https://huggingface.co/timm/vit_base_patch32_clip_384.laion2b_ft_in12k_in1k) |
| vit_base_patch16_clip_224.openai_ft_in1k | 85.3 | 86.6 | 17.6 | 23.9 | [link](https://huggingface.co/timm/vit_base_patch16_clip_224.openai_ft_in1k) |
| vit_base_patch32_clip_384.openai_ft_in12k_in1k | 85.2 | 88.3 | 13.1 | 16.5 | [link](https://huggingface.co/timm/vit_base_patch32_clip_384.openai_ft_in12k_in1k) |
| vit_base_patch32_clip_224.laion2b_ft_in12k_in1k | 83.3 | 88.2 | 4.4 | 5 | [link](https://huggingface.co/timm/vit_base_patch32_clip_224.laion2b_ft_in12k_in1k) |
| vit_base_patch32_clip_224.laion2b_ft_in1k | 82.6 | 88.2 | 4.4 | 5 | [link](https://huggingface.co/timm/vit_base_patch32_clip_224.laion2b_ft_in1k) |
| vit_base_patch32_clip_224.openai_ft_in1k | 81.9 | 88.2 | 4.4 | 5 | [link](https://huggingface.co/timm/vit_base_patch32_clip_224.openai_ft_in1k) |
* Port of MaxViT Tensorflow Weights from official impl at https://github.com/google-research/maxvit
* There was larger than expected drops for the upscaled 384/512 in21k fine-tune weights, possible detail missing, but the 21k FT did seem sensitive to small preprocessing
| model | top1 | param_count | gmac | macts | hub |
|:-----------------------------------|-------:|--------------:|-------:|--------:|:-----------------------------------------------------------------------|
| maxvit_xlarge_tf_512.in21k_ft_in1k | 88.5 | 475.8 | 534.1 | 1413.2 | [link](https://huggingface.co/timm/maxvit_xlarge_tf_512.in21k_ft_in1k) |
| maxvit_xlarge_tf_384.in21k_ft_in1k | 88.3 | 475.3 | 292.8 | 668.8 | [link](https://huggingface.co/timm/maxvit_xlarge_tf_384.in21k_ft_in1k) |
| maxvit_base_tf_512.in21k_ft_in1k | 88.2 | 119.9 | 138 | 704 | [link](https://huggingface.co/timm/maxvit_base_tf_512.in21k_ft_in1k) |
| maxvit_large_tf_512.in21k_ft_in1k | 88 | 212.3 | 244.8 | 942.2 | [link](https://huggingface.co/timm/maxvit_large_tf_512.in21k_ft_in1k) |
| maxvit_large_tf_384.in21k_ft_in1k | 88 | 212 | 132.6 | 445.8 | [link](https://huggingface.co/timm/maxvit_large_tf_384.in21k_ft_in1k) |
| maxvit_base_tf_384.in21k_ft_in1k | 87.9 | 119.6 | 73.8 | 332.9 | [link](https://huggingface.co/timm/maxvit_base_tf_384.in21k_ft_in1k) |
| maxvit_base_tf_512.in1k | 86.6 | 119.9 | 138 | 704 | [link](https://huggingface.co/timm/maxvit_base_tf_512.in1k) |
| maxvit_large_tf_512.in1k | 86.5 | 212.3 | 244.8 | 942.2 | [link](https://huggingface.co/timm/maxvit_large_tf_512.in1k) |
| maxvit_base_tf_384.in1k | 86.3 | 119.6 | 73.8 | 332.9 | [link](https://huggingface.co/timm/maxvit_base_tf_384.in1k) |
| maxvit_large_tf_384.in1k | 86.2 | 212 | 132.6 | 445.8 | [link](https://huggingface.co/timm/maxvit_large_tf_384.in1k) |
| maxvit_small_tf_512.in1k | 86.1 | 69.1 | 67.3 | 383.8 | [link](https://huggingface.co/timm/maxvit_small_tf_512.in1k) |
| maxvit_tiny_tf_512.in1k | 85.7 | 31 | 33.5 | 257.6 | [link](https://huggingface.co/timm/maxvit_tiny_tf_512.in1k) |
| maxvit_small_tf_384.in1k | 85.5 | 69 | 35.9 | 183.6 | [link](https://huggingface.co/timm/maxvit_small_tf_384.in1k) |
| maxvit_tiny_tf_384.in1k | 85.1 | 31 | 17.5 | 123.4 | [link](https://huggingface.co/timm/maxvit_tiny_tf_384.in1k) |
| maxvit_large_tf_224.in1k | 84.9 | 211.8 | 43.7 | 127.4 | [link](https://huggingface.co/timm/maxvit_large_tf_224.in1k) |
| maxvit_base_tf_224.in1k | 84.9 | 119.5 | 24 | 95 | [link](https://huggingface.co/timm/maxvit_base_tf_224.in1k) |
| maxvit_small_tf_224.in1k | 84.4 | 68.9 | 11.7 | 53.2 | [link](https://huggingface.co/timm/maxvit_small_tf_224.in1k) |
| maxvit_tiny_tf_224.in1k | 83.4 | 30.9 | 5.6 | 35.8 | [link](https://huggingface.co/timm/maxvit_tiny_tf_224.in1k) |
### Oct 15, 2022
* Train and validation script enhancements
* Non-GPU (ie CPU) device support
* SLURM compatibility for train script
* HF datasets support (via ReaderHfds)
* TFDS/WDS dataloading improvements (sample padding/wrap for distributed use fixed wrt sample count estimate)
* in_chans !=3 support for scripts / loader
* Adan optimizer
* Can enable per-step LR scheduling via args
* Dataset 'parsers' renamed to 'readers', more descriptive of purpose
* AMP args changed, APEX via `--amp-impl apex`, bfloat16 supportedf via `--amp-dtype bfloat16`
* main branch switched to 0.7.x version, 0.6x forked for stable release of weight only adds
* master -> main branch rename
### Oct 10, 2022
* More weights in `maxxvit` series, incl first ConvNeXt block based `coatnext` and `maxxvit` experiments:
* `coatnext_nano_rw_224` - 82.0 @ 224 (G) -- (uses ConvNeXt conv block, no BatchNorm)
* `maxxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.7 @ 320 (G) (uses ConvNeXt conv block, no BN)
* `maxvit_rmlp_small_rw_224` - 84.5 @ 224, 85.1 @ 320 (G)
* `maxxvit_rmlp_small_rw_256` - 84.6 @ 256, 84.9 @ 288 (G) -- could be trained better, hparams need tuning (uses ConvNeXt block, no BN)
* `coatnet_rmlp_2_rw_224` - 84.6 @ 224, 85 @ 320 (T)
* NOTE: official MaxVit weights (in1k) have been released at https://github.com/google-research/maxvit -- some extra work is needed to port and adapt since my impl was created independently of theirs and has a few small differences + the whole TF same padding fun.
### Sept 23, 2022
* LAION-2B CLIP image towers supported as pretrained backbones for fine-tune or features (no classifier)
* vit_base_patch32_224_clip_laion2b
* vit_large_patch14_224_clip_laion2b
* vit_huge_patch14_224_clip_laion2b
* vit_giant_patch14_224_clip_laion2b
### Sept 7, 2022
* Hugging Face [`timm` docs](https://huggingface.co/docs/hub/timm) home now exists, look for more here in the future
* Add BEiT-v2 weights for base and large 224x224 models from https://github.com/microsoft/unilm/tree/master/beit2
* Add more weights in `maxxvit` series incl a `pico` (7.5M params, 1.9 GMACs), two `tiny` variants:
* `maxvit_rmlp_pico_rw_256` - 80.5 @ 256, 81.3 @ 320 (T)
* `maxvit_tiny_rw_224` - 83.5 @ 224 (G)
* `maxvit_rmlp_tiny_rw_256` - 84.2 @ 256, 84.8 @ 320 (T)
### Aug 29, 2022
* MaxVit window size scales with img_size by default. Add new RelPosMlp MaxViT weight that leverages this:
* `maxvit_rmlp_nano_rw_256` - 83.0 @ 256, 83.6 @ 320 (T)
### Aug 26, 2022
* CoAtNet (https://arxiv.org/abs/2106.04803) and MaxVit (https://arxiv.org/abs/2204.01697) `timm` original models
* both found in [`maxxvit.py`](https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/maxxvit.py) model def, contains numerous experiments outside scope of original papers
* an unfinished Tensorflow version from MaxVit authors can be found https://github.com/google-research/maxvit
* Initial CoAtNet and MaxVit timm pretrained weights (working on more):
* `coatnet_nano_rw_224` - 81.7 @ 224 (T)
* `coatnet_rmlp_nano_rw_224` - 82.0 @ 224, 82.8 @ 320 (T)
* `coatnet_0_rw_224` - 82.4 (T) -- NOTE timm '0' coatnets have 2 more 3rd stage blocks
* `coatnet_bn_0_rw_224` - 82.4 (T)
* `maxvit_nano_rw_256` - 82.9 @ 256 (T)
* `coatnet_rmlp_1_rw_224` - 83.4 @ 224, 84 @ 320 (T)
* `coatnet_1_rw_224` - 83.6 @ 224 (G)
* (T) = TPU trained with `bits_and_tpu` branch training code, (G) = GPU trained
* GCVit (weights adapted from https://github.com/NVlabs/GCVit, code 100% `timm` re-write for license purposes)
* MViT-V2 (multi-scale vit, adapted from https://github.com/facebookresearch/mvit)
* EfficientFormer (adapted from https://github.com/snap-research/EfficientFormer)
* PyramidVisionTransformer-V2 (adapted from https://github.com/whai362/PVT)
* 'Fast Norm' support for LayerNorm and GroupNorm that avoids float32 upcast w/ AMP (uses APEX LN if available for further boost)
### Aug 15, 2022
* ConvNeXt atto weights added
* `convnext_atto` - 75.7 @ 224, 77.0 @ 288
* `convnext_atto_ols` - 75.9 @ 224, 77.2 @ 288
### Aug 5, 2022
* More custom ConvNeXt smaller model defs with weights
* `convnext_femto` - 77.5 @ 224, 78.7 @ 288
* `convnext_femto_ols` - 77.9 @ 224, 78.9 @ 288
* `convnext_pico` - 79.5 @ 224, 80.4 @ 288
* `convnext_pico_ols` - 79.5 @ 224, 80.5 @ 288
* `convnext_nano_ols` - 80.9 @ 224, 81.6 @ 288
* Updated EdgeNeXt to improve ONNX export, add new base variant and weights from original (https://github.com/mmaaz60/EdgeNeXt)
### July 28, 2022
* Add freshly minted DeiT-III Medium (width=512, depth=12, num_heads=8) model weights. Thanks [Hugo Touvron](https://github.com/TouvronHugo)!
### July 27, 2022
* All runtime benchmark and validation result csv files are up-to-date!
@ -133,42 +312,3 @@ More models, more fixes
* TinyNet models added by [rsomani95](https://github.com/rsomani95)
* LCNet added via MobileNetV3 architecture
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
* `halonet50ts` - 81.7 @ 256
* `halo2botnet50ts_256` - 82.0 @ 256
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Oct 19, 2021
* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* BCE loss and Repeated Augmentation support for RSB paper
* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (https://arxiv.org/abs/2103.12731)
* Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
* LambdaNetworks (https://arxiv.org/abs/2102.08602)
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.

@ -0,0 +1,14 @@
# Hugging Face Timm Docs
## Getting Started
```
pip install git+https://github.com/huggingface/doc-builder.git@main#egg=hf-doc-builder
pip install watchdog black
```
## Preview the Docs Locally
```
doc-builder preview timm hfdocs/source
```

@ -1,149 +1,160 @@
- sections:
- local: index
title: Pytorch Image Models (timm)
title: Home
- local: quickstart
title: Quickstart
- local: installation
title: Installation
title: Get started
- sections:
- local: feature_extraction
title: Using Pretrained Models as Feature Extractors
- local: training_script
title: Training With The Official Training Script
- local: hf_hub
title: Share and Load Models from the 🤗 Hugging Face Hub
title: Tutorials
- sections:
- local: models
title: Model Summaries
- local: results
title: Results
- local: scripts
title: Scripts
- local: training_hparam_examples
title: Training Examples
- local: feature_extraction
title: Feature Extraction
- local: changes
title: Recent Changes
- local: archived_changes
title: Archived Changes
- local: model_pages
title: Model Pages
isExpanded: false
sections:
- local: models/adversarial-inception-v3
title: Adversarial Inception v3
- local: models/advprop
title: AdvProp (EfficientNet)
- local: models/big-transfer
title: Big Transfer (BiT)
- local: models/csp-darknet
title: CSP-DarkNet
- local: models/csp-resnet
title: CSP-ResNet
- local: models/csp-resnext
title: CSP-ResNeXt
- local: models/densenet
title: DenseNet
- local: models/dla
title: Deep Layer Aggregation
- local: models/dpn
title: Dual Path Network (DPN)
- local: models/ecaresnet
title: ECA-ResNet
- local: models/efficientnet
title: EfficientNet
- local: models/efficientnet-pruned
title: EfficientNet (Knapsack Pruned)
- local: models/ensemble-adversarial
title: Ensemble Adversarial Inception ResNet v2
- local: models/ese-vovnet
title: ESE-VoVNet
- local: models/fbnet
title: FBNet
- local: models/gloun-inception-v3
title: (Gluon) Inception v3
- local: models/gloun-resnet
title: (Gluon) ResNet
- local: models/gloun-resnext
title: (Gluon) ResNeXt
- local: models/gloun-senet
title: (Gluon) SENet
- local: models/gloun-seresnext
title: (Gluon) SE-ResNeXt
- local: models/gloun-xception
title: (Gluon) Xception
- local: models/hrnet
title: HRNet
- local: models/ig-resnext
title: Instagram ResNeXt WSL
- local: models/inception-resnet-v2
title: Inception ResNet v2
- local: models/inception-v3
title: Inception v3
- local: models/inception-v4
title: Inception v4
- local: models/legacy-se-resnet
title: (Legacy) SE-ResNet
- local: models/legacy-se-resnext
title: (Legacy) SE-ResNeXt
- local: models/legacy-senet
title: (Legacy) SENet
- local: models/mixnet
title: MixNet
- local: models/mnasnet
title: MnasNet
- local: models/mobilenet-v2
title: MobileNet v2
- local: models/mobilenet-v3
title: MobileNet v3
- local: models/nasnet
title: NASNet
- local: models/noisy-student
title: Noisy Student (EfficientNet)
- local: models/pnasnet
title: PNASNet
- local: models/regnetx
title: RegNetX
- local: models/regnety
title: RegNetY
- local: models/res2net
title: Res2Net
- local: models/res2next
title: Res2NeXt
- local: models/resnest
title: ResNeSt
- local: models/resnet
title: ResNet
- local: models/resnet-d
title: ResNet-D
- local: models/resnext
title: ResNeXt
- local: models/rexnet
title: RexNet
- local: models/se-resnet
title: SE-ResNet
- local: models/selecsls
title: SelecSLS
- local: models/seresnext
title: SE-ResNeXt
- local: models/skresnet
title: SK-ResNet
- local: models/skresnext
title: SK-ResNeXt
- local: models/spnasnet
title: SPNASNet
- local: models/ssl-resnet
title: SSL ResNet
- local: models/swsl-resnet
title: SWSL ResNet
- local: models/swsl-resnext
title: SWSL ResNeXt
- local: models/tf-efficientnet
title: (Tensorflow) EfficientNet
- local: models/tf-efficientnet-condconv
title: (Tensorflow) EfficientNet CondConv
- local: models/tf-efficientnet-lite
title: (Tensorflow) EfficientNet Lite
- local: models/tf-inception-v3
title: (Tensorflow) Inception v3
- local: models/tf-mixnet
title: (Tensorflow) MixNet
- local: models/tf-mobilenet-v3
title: (Tensorflow) MobileNet v3
- local: models/tresnet
title: TResNet
- local: models/wide-resnet
title: Wide ResNet
- local: models/xception
title: Xception
title: Get started
- local: models/adversarial-inception-v3
title: Adversarial Inception v3
- local: models/advprop
title: AdvProp (EfficientNet)
- local: models/big-transfer
title: Big Transfer (BiT)
- local: models/csp-darknet
title: CSP-DarkNet
- local: models/csp-resnet
title: CSP-ResNet
- local: models/csp-resnext
title: CSP-ResNeXt
- local: models/densenet
title: DenseNet
- local: models/dla
title: Deep Layer Aggregation
- local: models/dpn
title: Dual Path Network (DPN)
- local: models/ecaresnet
title: ECA-ResNet
- local: models/efficientnet
title: EfficientNet
- local: models/efficientnet-pruned
title: EfficientNet (Knapsack Pruned)
- local: models/ensemble-adversarial
title: Ensemble Adversarial Inception ResNet v2
- local: models/ese-vovnet
title: ESE-VoVNet
- local: models/fbnet
title: FBNet
- local: models/gloun-inception-v3
title: (Gluon) Inception v3
- local: models/gloun-resnet
title: (Gluon) ResNet
- local: models/gloun-resnext
title: (Gluon) ResNeXt
- local: models/gloun-senet
title: (Gluon) SENet
- local: models/gloun-seresnext
title: (Gluon) SE-ResNeXt
- local: models/gloun-xception
title: (Gluon) Xception
- local: models/hrnet
title: HRNet
- local: models/ig-resnext
title: Instagram ResNeXt WSL
- local: models/inception-resnet-v2
title: Inception ResNet v2
- local: models/inception-v3
title: Inception v3
- local: models/inception-v4
title: Inception v4
- local: models/legacy-se-resnet
title: (Legacy) SE-ResNet
- local: models/legacy-se-resnext
title: (Legacy) SE-ResNeXt
- local: models/legacy-senet
title: (Legacy) SENet
- local: models/mixnet
title: MixNet
- local: models/mnasnet
title: MnasNet
- local: models/mobilenet-v2
title: MobileNet v2
- local: models/mobilenet-v3
title: MobileNet v3
- local: models/nasnet
title: NASNet
- local: models/noisy-student
title: Noisy Student (EfficientNet)
- local: models/pnasnet
title: PNASNet
- local: models/regnetx
title: RegNetX
- local: models/regnety
title: RegNetY
- local: models/res2net
title: Res2Net
- local: models/res2next
title: Res2NeXt
- local: models/resnest
title: ResNeSt
- local: models/resnet
title: ResNet
- local: models/resnet-d
title: ResNet-D
- local: models/resnext
title: ResNeXt
- local: models/rexnet
title: RexNet
- local: models/se-resnet
title: SE-ResNet
- local: models/selecsls
title: SelecSLS
- local: models/seresnext
title: SE-ResNeXt
- local: models/skresnet
title: SK-ResNet
- local: models/skresnext
title: SK-ResNeXt
- local: models/spnasnet
title: SPNASNet
- local: models/ssl-resnet
title: SSL ResNet
- local: models/swsl-resnet
title: SWSL ResNet
- local: models/swsl-resnext
title: SWSL ResNeXt
- local: models/tf-efficientnet
title: (Tensorflow) EfficientNet
- local: models/tf-efficientnet-condconv
title: (Tensorflow) EfficientNet CondConv
- local: models/tf-efficientnet-lite
title: (Tensorflow) EfficientNet Lite
- local: models/tf-inception-v3
title: (Tensorflow) Inception v3
- local: models/tf-mixnet
title: (Tensorflow) MixNet
- local: models/tf-mobilenet-v3
title: (Tensorflow) MobileNet v3
- local: models/tresnet
title: TResNet
- local: models/wide-resnet
title: Wide ResNet
- local: models/xception
title: Xception
title: Model Pages
isExpanded: false
- sections:
- local: reference/models
title: Models
- local: reference/data
title: Data
- local: reference/optimizers
title: Optimizers
- local: reference/schedulers
title: Learning Rate Schedulers
title: Reference

@ -1,418 +0,0 @@
# Archived Changes
### July 12, 2021
* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
### July 5-9, 2021
* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
* top-1 82.34 @ 288x288 and 82.54 @ 320x320
* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
* `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426
### June 23, 2021
* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
### June 20, 2021
* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
* .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
* See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights
* Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
* Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
* `vit_deit_*` renamed to just `deit_*`
* Remove my old small model, replace with DeiT compatible small w/ AugReg weights
* Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params.
* Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
* NFNets and ResNetV2-BiT models work w/ Pytorch XLA now
* weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
* eps values adjusted, will be slight differences but should be quite close
* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
* Please report any regressions, this PR touched quite a few models.
### June 8, 2021
* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
* NFNet inspired block layout with quad layer stem and no maxpool
* Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
### May 25, 2021
* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Cleanup input_size/img_size override handling and testing for all vision transformer models
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training
### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modles
* Improve SA module inits
* Hack together experimental stand-alone Swin based attn module and `swinnet`
* Consistent '26t' model defs for experiments.
* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1.
* WandB logging support
### April 13, 2021
* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer
### April 12, 2021
* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
* Lambda Networks - https://arxiv.org/abs/2102.08602
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* Halo Nets - https://arxiv.org/abs/2103.12731
* Adabelief optimizer contributed by Juntang Zhuang
### April 1, 2021
* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
* Merged distilled variant into main for torchscript compatibility
* Some `timm` cleanup/style tweaks and weights have hub download support
* Cleanup Vision Transformer (ViT) models
* Merge distilled (DeiT) model into main so that torchscript can work
* Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
* Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
* Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants
* nn.Sequential for block stack (does not break downstream compat)
* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
* Add RegNetY-160 weights from DeiT teacher model
* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288
* Some fixes/improvements for TFDS dataset wrapper
### March 7, 2021
* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
### Feb 18, 2021
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
* Matching the original pre-processing as closely as possible I get these results:
* `dm_nfnet_f6` - 86.352
* `dm_nfnet_f5` - 86.100
* `dm_nfnet_f4` - 85.834
* `dm_nfnet_f3` - 85.676
* `dm_nfnet_f2` - 85.178
* `dm_nfnet_f1` - 84.696
* `dm_nfnet_f0` - 83.464
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
* PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
* PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
* AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
### Feb 10, 2021
* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
* GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
* RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
* classic VGG (from torchvision, impl in `vgg`)
* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
* Fix a few bugs introduced since last pypi release
### Feb 8, 2021
* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352.
* `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
* `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
* `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test.
### Jan 30, 2021
* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)
### Jan 25, 2021
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
* NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
* Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
* Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling
### Jan 3, 2021
* Add SE-ResNet-152D weights
* 256x256 val, 0.94 crop top-1 - 83.75
* 320x320 val, 1.0 crop - 84.36
* Update results files
### Dec 18, 2020
* Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256
* 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25)
* 288x288 val, 1.0 crop - 101D (82.64), 152D (83.48), 200D (83.76)
* 320x320 val, 1.0 crop - 101D (83.00), 152D (83.66), 200D (84.01)
### Dec 7, 2020
* Simplify EMA module (ModelEmaV2), compatible with fully torchscripted models
* Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript
* PyPi release @ 0.3.2 (needed by EfficientDet)
### Oct 30, 2020
* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!
### Oct 26, 2020
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
* ViT-B/16 - 84.2
* ViT-B/32 - 81.7
* ViT-L/16 - 85.2
* ViT-L/32 - 81.5
### Oct 21, 2020
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
### Oct 13, 2020
* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...
### Sept 18, 2020
* New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D
* Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D)
### Sept 3, 2020
* New weights
* Wide-ResNet50 - 81.5 top-1 (vs 78.5 torchvision)
* SEResNeXt50-32x4d - 81.3 top-1 (vs 79.1 cadene)
* Support for native Torch AMP and channels_last memory format added to train/validate scripts (`--channels-last`, `--native-amp` vs `--apex-amp`)
* Models tested with channels_last on latest NGC 20.08 container. AdaptiveAvgPool in attn layers changed to mean((2,3)) to work around bug with NHWC kernel.
### Aug 12, 2020
* New/updated weights from training experiments
* EfficientNet-B3 - 82.1 top-1 (vs 81.6 for official with AA and 81.9 for AdvProp)
* RegNetY-3.2GF - 82.0 top-1 (78.9 from official ver)
* CSPResNet50 - 79.6 top-1 (76.6 from official ver)
* Add CutMix integrated w/ Mixup. See [pull request](https://github.com/rwightman/pytorch-image-models/pull/218) for some usage examples
* Some fixes for using pretrained weights with `in_chans` != 3 on several models.
### Aug 5, 2020
Universal feature extraction, new models, new weights, new test sets.
* All models support the `features_only=True` argument for `create_model` call to return a network that extracts feature maps from the deepest layer at each stride.
* New models
* CSPResNet, CSPResNeXt, CSPDarkNet, DarkNet
* ReXNet
* (Modified Aligned) Xception41/65/71 (a proper port of TF models)
* New trained weights
* SEResNet50 - 80.3 top-1
* CSPDarkNet53 - 80.1 top-1
* CSPResNeXt50 - 80.0 top-1
* DPN68b - 79.2 top-1
* EfficientNet-Lite0 (non-TF ver) - 75.5 (submitted by [@hal-314](https://github.com/hal-314))
* Add 'real' labels for ImageNet and ImageNet-Renditions test set, see [`results/README.md`](results/README.md)
* Test set ranking/top-n diff script by [@KushajveerSingh](https://github.com/KushajveerSingh)
* Train script and loader/transform tweaks to punch through more aug arguments
* README and documentation overhaul. See initial (WIP) documentation at https://rwightman.github.io/pytorch-image-models/
* adamp and sgdp optimizers added by [@hellbell](https://github.com/hellbell)
### June 11, 2020
Bunch of changes:
* DenseNet models updated with memory efficient addition from torchvision (fixed a bug), blur pooling and deep stem additions
* VoVNet V1 and V2 models added, 39 V2 variant (ese_vovnet_39b) trained to 79.3 top-1
* Activation factory added along with new activations:
* select act at model creation time for more flexibility in using activations compatible with scripting or tracing (ONNX export)
* hard_mish (experimental) added with memory-efficient grad, along with ME hard_swish
* context mgr for setting exportable/scriptable/no_jit states
* Norm + Activation combo layers added with initial trial support in DenseNet and VoVNet along with impl of EvoNorm and InplaceAbn wrapper that fit the interface
* Torchscript works for all but two of the model types as long as using Pytorch 1.5+, tests added for this
* Some import cleanup and classifier reset changes, all models will have classifier reset to nn.Identity on reset_classifer(0) call
* Prep for 0.1.28 pip release
### May 12, 2020
* Add ResNeSt models (code adapted from https://github.com/zhanghang1989/ResNeSt, paper https://arxiv.org/abs/2004.08955))
### May 3, 2020
* Pruned EfficientNet B1, B2, and B3 (https://arxiv.org/abs/2002.08258) contributed by [Yonathan Aflalo](https://github.com/yoniaflalo)
### May 1, 2020
* Merged a number of execellent contributions in the ResNet model family over the past month
* BlurPool2D and resnetblur models initiated by [Chris Ha](https://github.com/VRandme), I trained resnetblur50 to 79.3.
* TResNet models and SpaceToDepth, AntiAliasDownsampleLayer layers by [mrT23](https://github.com/mrT23)
* ecaresnet (50d, 101d, light) models and two pruned variants using pruning as per (https://arxiv.org/abs/2002.08258) by [Yonathan Aflalo](https://github.com/yoniaflalo)
* 200 pretrained models in total now with updated results csv in results folder
### April 5, 2020
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
* 3.5M param MobileNet-V2 100 @ 73%
* 4.5M param MobileNet-V2 110d @ 75%
* 6.1M param MobileNet-V2 140 @ 76.5%
* 5.8M param MobileNet-V2 120d @ 77.3%
### March 18, 2020
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
### April 5, 2020
* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite
* 3.5M param MobileNet-V2 100 @ 73%
* 4.5M param MobileNet-V2 110d @ 75%
* 6.1M param MobileNet-V2 140 @ 76.5%
* 5.8M param MobileNet-V2 120d @ 77.3%
### March 18, 2020
* Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite)
* Add RandAugment trained ResNeXt-50 32x4d weights with 79.8 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
### Feb 29, 2020
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
* overall results similar to a bit better training from scratch on a few smaller models tried
* performance early in training seems consistently improved but less difference by end
* set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour
* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training
### Feb 18, 2020
* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
* Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
* ResNet downsample paths now properly support dilation (output stride != 32) for avg_pool ('D' variant) and 3x3 (SENets) networks
* Add Selective Kernel Nets on top of ResNet base, pretrained weights
* skresnet18 - 73% top-1
* skresnet34 - 76.9% top-1
* skresnext50_32x4d (equiv to SKNet50) - 80.2% top-1
* ECA and CECA (circular padding) attention layer contributed by [Chris Ha](https://github.com/VRandme)
* CBAM attention experiment (not the best results so far, may remove)
* Attention factory to allow dynamically selecting one of SE, ECA, CBAM in the `.se` position for all ResNets
* Add DropBlock and DropPath (formerly DropConnect for EfficientNet/MobileNetv3) support to all ResNet variants
* Full dataset results updated that incl NoisyStudent weights and 2 of the 3 SK weights
### Feb 12, 2020
* Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet)
### Feb 6, 2020
* Add RandAugment trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) (see Training section for hparams)
### Feb 1/2, 2020
* Port new EfficientNet-B8 (RandAugment) weights, these are different than the B8 AdvProp, different input normalization.
* Update results csv files on all models for ImageNet validation and three other test sets
* Push PyPi package update
### Jan 31, 2020
* Update ResNet50 weights with a new 79.038 result from further JSD / AugMix experiments. Full command line for reproduction in training section below.
### Jan 11/12, 2020
* Master may be a bit unstable wrt to training, these changes have been tested but not all combos
* Implementations of AugMix added to existing RA and AA. Including numerous supporting pieces like JSD loss (Jensen-Shannon divergence + CE), and AugMixDataset
* SplitBatchNorm adaptation layer added for implementing Auxiliary BN as per AdvProp paper
* ResNet-50 AugMix trained model w/ 79% top-1 added
* `seresnext26tn_32x4d` - 77.99 top-1, 93.75 top-5 added to tiered experiment, higher img/s than 't' and 'd'
### Jan 3, 2020
* Add RandAugment trained EfficientNet-B0 weight with 77.7 top-1. Trained by [Michael Klachko](https://github.com/michaelklachko) with this code and recent hparams (see Training section)
* Add `avg_checkpoints.py` script for post training weight averaging and update all scripts with header docstrings and shebangs.
### Dec 30, 2019
* Merge [Dushyant Mehta's](https://github.com/mehtadushy) PR for SelecSLS (Selective Short and Long Range Skip Connections) networks. Good GPU memory consumption and throughput. Original: https://github.com/mehtadushy/SelecSLS-Pytorch
### Dec 28, 2019
* Add new model weights and training hparams (see Training Hparams section)
* `efficientnet_b3` - 81.5 top-1, 95.7 top-5 at default res/crop, 81.9, 95.8 at 320x320 1.0 crop-pct
* trained with RandAugment, ended up with an interesting but less than perfect result (see training section)
* `seresnext26d_32x4d`- 77.6 top-1, 93.6 top-5
* deep stem (32, 32, 64), avgpool downsample
* stem/dowsample from bag-of-tricks paper
* `seresnext26t_32x4d`- 78.0 top-1, 93.7 top-5
* deep tiered stem (24, 48, 64), avgpool downsample (a modified 'D' variant)
* stem sizing mods from Jeremy Howard and fastai devs discussing ResNet architecture experiments
### Dec 23, 2019
* Add RandAugment trained MixNet-XL weights with 80.48 top-1.
* `--dist-bn` argument added to train.py, will distribute BN stats between nodes after each train epoch, before eval
### Dec 4, 2019
* Added weights from the first training from scratch of an EfficientNet (B2) with my new RandAugment implementation. Much better than my previous B2 and very close to the official AdvProp ones (80.4 top-1, 95.08 top-5).
### Nov 29, 2019
* Brought EfficientNet and MobileNetV3 up to date with my https://github.com/rwightman/gen-efficientnet-pytorch code. Torchscript and ONNX export compat excluded.
* AdvProp weights added
* Official TF MobileNetv3 weights added
* EfficientNet and MobileNetV3 hook based 'feature extraction' classes added. Will serve as basis for using models as backbones in obj detection/segmentation tasks. Lots more to be done here...
* HRNet classification models and weights added from https://github.com/HRNet/HRNet-Image-Classification
* Consistency in global pooling, `reset_classifer`, and `forward_features` across models
* `forward_features` always returns unpooled feature maps now
* Reasonable chance I broke something... let me know
### Nov 22, 2019
* Add ImageNet training RandAugment implementation alongside AutoAugment. PyTorch Transform compatible format, using PIL. Currently training two EfficientNet models from scratch with promising results... will update.
* `drop-connect` cmd line arg finally added to `train.py`, no need to hack model fns. Works for efficientnet/mobilenetv3 based models, ignored otherwise.

@ -1,187 +0,0 @@
# Recent Changes
### July 27, 2022
* All runtime benchmark and validation result csv files are up-to-date!
* A few more weights & model defs added:
* `darknetaa53` - 79.8 @ 256, 80.5 @ 288
* `convnext_nano` - 80.8 @ 224, 81.5 @ 288
* `cs3sedarknet_l` - 81.2 @ 256, 81.8 @ 288
* `cs3darknet_x` - 81.8 @ 256, 82.2 @ 288
* `cs3sedarknet_x` - 82.2 @ 256, 82.7 @ 288
* `cs3edgenet_x` - 82.2 @ 256, 82.7 @ 288
* `cs3se_edgenet_x` - 82.8 @ 256, 83.5 @ 320
* `cs3*` weights above all trained on TPU w/ `bits_and_tpu` branch. Thanks to TRC program!
* Add output_stride=8 and 16 support to ConvNeXt (dilation)
* deit3 models not being able to resize pos_emb fixed
* Version 0.6.7 PyPi release (/w above bug fixes and new weighs since 0.6.5)
### July 8, 2022
More models, more fixes
* Official research models (w/ weights) added:
* EdgeNeXt from (https://github.com/mmaaz60/EdgeNeXt)
* MobileViT-V2 from (https://github.com/apple/ml-cvnets)
* DeiT III (Revenge of the ViT) from (https://github.com/facebookresearch/deit)
* My own models:
* Small `ResNet` defs added by request with 1 block repeats for both basic and bottleneck (resnet10 and resnet14)
* `CspNet` refactored with dataclass config, simplified CrossStage3 (`cs3`) option. These are closer to YOLO-v5+ backbone defs.
* More relative position vit fiddling. Two `srelpos` (shared relative position) models trained, and a medium w/ class token.
* Add an alternate downsample mode to EdgeNeXt and train a `small` model. Better than original small, but not their new USI trained weights.
* My own model weight results (all ImageNet-1k training)
* `resnet10t` - 66.5 @ 176, 68.3 @ 224
* `resnet14t` - 71.3 @ 176, 72.3 @ 224
* `resnetaa50` - 80.6 @ 224 , 81.6 @ 288
* `darknet53` - 80.0 @ 256, 80.5 @ 288
* `cs3darknet_m` - 77.0 @ 256, 77.6 @ 288
* `cs3darknet_focus_m` - 76.7 @ 256, 77.3 @ 288
* `cs3darknet_l` - 80.4 @ 256, 80.9 @ 288
* `cs3darknet_focus_l` - 80.3 @ 256, 80.9 @ 288
* `vit_srelpos_small_patch16_224` - 81.1 @ 224, 82.1 @ 320
* `vit_srelpos_medium_patch16_224` - 82.3 @ 224, 83.1 @ 320
* `vit_relpos_small_patch16_cls_224` - 82.6 @ 224, 83.6 @ 320
* `edgnext_small_rw` - 79.6 @ 224, 80.4 @ 320
* `cs3`, `darknet`, and `vit_*relpos` weights above all trained on TPU thanks to TRC program! Rest trained on overheating GPUs.
* Hugging Face Hub support fixes verified, demo notebook TBA
* Pretrained weights / configs can be loaded externally (ie from local disk) w/ support for head adaptation.
* Add support to change image extensions scanned by `timm` datasets/parsers. See (https://github.com/rwightman/pytorch-image-models/pull/1274#issuecomment-1178303103)
* Default ConvNeXt LayerNorm impl to use `F.layer_norm(x.permute(0, 2, 3, 1), ...).permute(0, 3, 1, 2)` via `LayerNorm2d` in all cases.
* a bit slower than previous custom impl on some hardware (ie Ampere w/ CL), but overall fewer regressions across wider HW / PyTorch version ranges.
* previous impl exists as `LayerNormExp2d` in `models/layers/norm.py`
* Numerous bug fixes
* Currently testing for imminent PyPi 0.6.x release
* LeViT pretraining of larger models still a WIP, they don't train well / easily without distillation. Time to add distill support (finally)?
* ImageNet-22k weight training + finetune ongoing, work on multi-weight support (slowly) chugging along (there are a LOT of weights, sigh) ...
### May 13, 2022
* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript.
* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects.
* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program)
* `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool
* `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake)
* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020)
* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials
* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg)
### May 2, 2022
* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`)
* `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool
* `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool
* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`)
* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae).
### April 22, 2022
* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/).
* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress.
* `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288
* `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288
### March 23, 2022
* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795)
* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs.
### March 21, 2022
* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required.
* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
* `regnety_040` - 82.3 @ 224, 82.96 @ 288
* `regnety_064` - 83.0 @ 224, 83.65 @ 288
* `regnety_080` - 83.17 @ 224, 83.86 @ 288
* `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
* `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
* `regnetz_040` - 83.67 @ 256, 84.25 @ 320
* `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)
* `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
* `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
* `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
* `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
* `xception41p` - 82 @ 299 (timm pre-act)
* `xception65` - 83.17 @ 299
* `xception65p` - 83.14 @ 299 (timm pre-act)
* `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288
* `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288
* `resnetrs200` - 83.85 @ 256, 84.44 @ 320
* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks.
* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
* VOLO models w/ weights adapted from https://github.com/sail-sg/volo
* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc
* Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception
* Grouped conv support added to EfficientNet family
* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler
* Gradient checkpointing support added to many models
* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head`
* All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head`
### Feb 2, 2022
* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so.
* The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
* `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable.
### Jan 14, 2022
* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
* `mnasnet_small` - 65.6 top-1
* `mobilenetv2_050` - 65.9
* `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
* `semnasnet_075` - 73
* `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
* TinyNet models added by [rsomani95](https://github.com/rsomani95)
* LCNet added via MobileNetV3 architecture
### Nov 22, 2021
* A number of updated weights anew new model defs
* `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
* `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* `sebotnet33ts_256` (new) - 81.2 @ 224
* `lamhalobotnet50ts_256` - 81.5 @ 256
* `halonet50ts` - 81.7 @ 256
* `halo2botnet50ts_256` - 82.0 @ 256
* `resnet101` - 82.0 @ 224, 82.8 @ 288
* `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `resnet152` - 82.8 @ 224, 83.5 @ 288
* `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Oct 19, 2021
* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* BCE loss and Repeated Augmentation support for RSB paper
* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (https://arxiv.org/abs/2103.12731)
* Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
* LambdaNetworks (https://arxiv.org/abs/2102.08602)
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
### Aug 18, 2021
* Optimizer bonanza!
* Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.

@ -0,0 +1,54 @@
# Sharing and Loading Models From the Hugging Face Hub
The `timm` library has a built-in integration with the Hugging Face Hub, making it easy to share and load models from the 🤗 Hub.
In this short guide, we'll see how to:
1. Share a `timm` model on the Hub
2. How to load that model back from the Hub
## Authenticating
First, you'll need to make sure you have the `huggingface_hub` package installed.
```bash
pip install huggingface_hub
```
Then, you'll need to authenticate yourself. You can do this by running the following command:
```bash
huggingface-cli login
```
Or, if you're using a notebook, you can use the `notebook_login` helper:
```py
>>> from huggingface_hub import notebook_login
>>> notebook_login()
```
## Sharing a Model
```py
>>> import timm
>>> model = timm.create_model('resnet18', pretrained=True, num_classes=4)
```
Here is where you would normally train or fine-tune the model. We'll skip that for the sake of this tutorial.
Let's pretend we've now fine-tuned the model. The next step would be to push it to the Hub! We can do this with the `timm.models.hub.push_to_hf_hub` function.
```py
>>> model_cfg = dict(labels=['a', 'b', 'c', 'd'])
>>> timm.models.hub.push_to_hf_hub(model, 'resnet18-random', model_config=model_cfg)
```
Running the above would push the model to `<your-username>/resnet18-random` on the Hub. You can now share this model with your friends, or use it in your own code!
## Loading a Model
Loading a model from the Hub is as simple as calling `timm.create_model` with the `pretrained` argument set to the name of the model you want to load. In this case, we'll use [`nateraw/resnet18-random`](https://huggingface.co/nateraw/resnet18-random), which is the model we just pushed to the Hub.
```py
>>> model_reloaded = timm.create_model('hf_hub:nateraw/resnet18-random', pretrained=True)
```

@ -1,89 +1,22 @@
# Getting Started
# timm
## Welcome
<img class="float-left !m-0 !border-0 !dark:border-0 !shadow-none !max-w-lg w-[150px]" src="https://huggingface.co/front/thumbnails/docs/timm.png"/>
Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`.
`timm` is a library containing SOTA computer vision models, layers, utilities, optimizers, schedulers, data-loaders, augmentations, and training/evaluation scripts.
For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora).
It comes packaged with >700 pretrained models, and is designed to be flexible and easy to use.
## Install
Read the [quick start guide](quickstart) to get up and running with the `timm` library. You will learn how to load, discover, and use pretrained models included in the library.
The library can be installed with pip:
```
pip install timm
```
I update the PyPi (pip) packages when I'm confident there are no significant model regressions from previous releases. If you want to pip install the bleeding edge from GitHub, use:
```
pip install git+https://github.com/rwightman/pytorch-image-models.git
```
### Conda Environment
<Tip>
- All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10
- Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment.
- PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code.
</Tip>
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
```bash
conda create -n torch-env
conda activate torch-env
conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
conda install pyyaml
```
## Load a Pretrained Model
Pretrained models can be loaded using `timm.create_model`
```py
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```
## List Models with Pretrained Weights
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
'adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
]
```
## List Model Architectures by Wildcard
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
'cspresnet50',
'cspresnet50d',
'cspresnet50w',
'cspresnext50',
...
]
```
<div class="mt-10">
<div class="w-full flex flex-col space-y-4 md:space-y-0 md:grid md:grid-cols-2 md:gap-y-4 md:gap-x-5">
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./feature_extraction"
><div class="w-full text-center bg-gradient-to-br from-blue-400 to-blue-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Tutorials</div>
<p class="text-gray-700">Learn the basics and become familiar with timm. Start here if you are using timm for the first time!</p>
</a>
<a class="!no-underline border dark:border-gray-700 p-5 rounded-lg shadow hover:shadow-lg" href="./reference/models"
><div class="w-full text-center bg-gradient-to-br from-purple-400 to-purple-500 rounded-lg py-1.5 font-semibold mb-5 text-white text-lg leading-relaxed">Reference</div>
<p class="text-gray-700">Technical descriptions of how timm classes and methods work.</p>
</a>
</div>
</div>

@ -0,0 +1,74 @@
# Installation
Before you start, you'll need to setup your environment and install the appropriate packages. `timm` is tested on **Python 3+**.
## Virtual Environment
You should install `timm` in a [virtual environment](https://docs.python.org/3/library/venv.html) to keep things tidy and avoid dependency conflicts.
1. Create and navigate to your project directory:
```bash
mkdir ~/my-project
cd ~/my-project
```
2. Start a virtual environment inside your directory:
```bash
python -m venv .env
```
3. Activate and deactivate the virtual environment with the following commands:
```bash
# Activate the virtual environment
source .env/bin/activate
# Deactivate the virtual environment
source .env/bin/deactivate
```
`
Once you've created your virtual environment, you can install `timm` in it.
## Using pip
The most straightforward way to install `timm` is with pip:
```bash
pip install timm
```
Alternatively, you can install `timm` from GitHub directly to get the latest, bleeding-edge version:
```bash
pip install git+https://github.com/rwightman/pytorch-image-models.git
```
Run the following command to check if `timm` has been properly installed:
```bash
python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
```
This command lists the first five pretrained models available in `timm` (which are sorted alphebetically). You should see the following output:
```python
['adv_inception_v3', 'bat_resnext26ts', 'beit_base_patch16_224', 'beit_base_patch16_224_in22k', 'beit_base_patch16_384']
```
## From Source
Building `timm` from source lets you make changes to the code base. To install from the source, clone the repository and install with the following commands:
```bash
git clone https://github.com/rwightman/pytorch-image-models.git
cd timm
pip install -e .
```
Again, you can check if `timm` was properly installed with the following command:
```bash
python -c "from timm import list_models; print(list_models(pretrained=True)[:5])"
```

@ -1,5 +0,0 @@
# Available Models
`timm` comes bundled with a number of model architectures and corresponding pretrained models.
In these pages, you will find the models available in the `timm` library, as well as information on how to use them.

@ -0,0 +1,228 @@
# Quickstart
This quickstart is intended for developers who are ready to dive into the code and see an example of how to integrate `timm` into their model training workflow.
First, you'll need to install `timm`. For more information on installation, see [Installation](installation).
```bash
pip install timm
```
## Load a Pretrained Model
Pretrained models can be loaded using [`create_model`].
Here, we load the pretrained `mobilenetv3_large_100` model.
```py
>>> import timm
>>> m = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> m.eval()
```
<Tip>
Note: The returned PyTorch model is set to train mode by default, so you must call .eval() on it if you plan to use it for inference.
</Tip>
## List Models with Pretrained Weights
To list models packaged with `timm`, you can use [`list_models`]. If you specify `pretrained=True`, this function will only return model names that have associated pretrained weights available.
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models(pretrained=True)
>>> pprint(model_names)
[
'adv_inception_v3',
'cspdarknet53',
'cspresnext50',
'densenet121',
'densenet161',
'densenet169',
'densenet201',
'densenetblur121d',
'dla34',
'dla46_c',
]
```
You can also list models with a specific pattern in their name.
```py
>>> import timm
>>> from pprint import pprint
>>> model_names = timm.list_models('*resne*t*')
>>> pprint(model_names)
[
'cspresnet50',
'cspresnet50d',
'cspresnet50w',
'cspresnext50',
...
]
```
## Fine-Tune a Pretrained Model
You can finetune any of the pre-trained models just by changing the classifier (the last layer).
```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True, num_classes=NUM_FINETUNE_CLASSES)
```
To fine-tune on your own dataset, you have to write a PyTorch training loop or adapt `timm`'s [training script](training_script) to use your dataset.
## Use a Pretrained Model for Feature Extraction
Without modifying the network, one can call model.forward_features(input) on any model instead of the usual model(input). This will bypass the head classifier and global pooling for networks.
For a more in depth guide to using `timm` for feature extraction, see [Feature Extraction](feature_extraction).
```py
>>> import timm
>>> import torch
>>> x = torch.randn(1, 3, 224, 224)
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True)
>>> features = model.forward_features(x)
>>> print(features.shape)
torch.Size([1, 960, 7, 7])
```
## Image Augmentation
To transform images into valid inputs for a model, you can use [`timm.data.create_transform`], providing the desired `input_size` that the model expects.
This will return a generic transform that uses reasonable defaults.
```py
>>> timm.data.create_transform((3, 224, 224))
Compose(
Resize(size=256, interpolation=bilinear, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```
Pretrained models have specific transforms that were applied to images fed into them while training. If you use the wrong transform on your image, the model won't understand what it's seeing!
To figure out which transformations were used for a given pretrained model, we can start by taking a look at its `pretrained_cfg`
```py
>>> model.pretrained_cfg
{'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
'num_classes': 1000,
'input_size': (3, 224, 224),
'pool_size': (7, 7),
'crop_pct': 0.875,
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'first_conv': 'conv_stem',
'classifier': 'classifier',
'architecture': 'mobilenetv3_large_100'}
```
We can then resolve only the data related configuration by using [`timm.data.resolve_data_config`].
```py
>>> timm.data.resolve_data_config(model.pretrained_cfg)
{'input_size': (3, 224, 224),
'interpolation': 'bicubic',
'mean': (0.485, 0.456, 0.406),
'std': (0.229, 0.224, 0.225),
'crop_pct': 0.875}
```
We can pass this data config to [`timm.data.create_transform`] to initialize the model's associated transform.
```py
>>> data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
>>> transform = timm.data.create_transform(**data_cfg)
>>> transform
Compose(
Resize(size=256, interpolation=bicubic, max_size=None, antialias=None)
CenterCrop(size=(224, 224))
ToTensor()
Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
)
```
<Tip>
Note: Here, the pretrained model's config happens to be the same as the generic config we made earlier. This is not always the case. So, it's safer to use the data config to create the transform as we did here instead of using the generic transform.
</Tip>
## Using Pretrained Models for Inference
Here, we will put together the above sections and use a pretrained model for inference.
First we'll need an image to do inference on. Here we load a picture of a leaf from the web:
```py
>>> import requests
>>> from PIL import Image
>>> from io import BytesIO
>>> url = 'https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg'
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> image
```
Here's the image we loaded:
<img src="https://datasets-server.huggingface.co/assets/imagenet-1k/--/default/test/12/image/image.jpg" alt="An Image from a link" width="300"/>
Now, we'll create our model and transforms again. This time, we make sure to set our model in evaluation mode.
```py
>>> model = timm.create_model('mobilenetv3_large_100', pretrained=True).eval()
>>> transform = timm.data.create_transform(
**timm.data.resolve_data_config(model.pretrained_cfg)
)
```
We can prepare this image for the model by passing it to the transform.
```py
>>> image_tensor = transform(image)
>>> image_tensor.shape
torch.Size([3, 224, 224])
```
Now we can pass that image to the model to get the predictions. We use `unsqueeze(0)` in this case, as the model is expecting a batch dimension.
```py
>>> output = model(image_tensor.unsqueeze(0))
>>> output.shape
torch.Size([1, 1000])
```
To get the predicted probabilities, we apply softmax to the output. This leaves us with a tensor of shape `(num_classes,)`.
```py
>>> probabilities = torch.nn.functional.softmax(output[0], dim=0)
>>> probabilities.shape
torch.Size([1000])
```
Now we'll find the top 5 predicted class indexes and values using `torch.topk`.
```py
>>> values, indices = torch.topk(probabilities, 5)
>>> indices
tensor([162, 166, 161, 164, 167])
```
If we check the imagenet labels for the top index, we can see what the model predicted...
```py
>>> IMAGENET_1k_URL = 'https://storage.googleapis.com/bit_models/ilsvrc2012_wordnet_lemmas.txt'
>>> IMAGENET_1k_LABELS = requests.get(IMAGENET_1k_URL).text.strip().split('\n')
>>> [{'label': IMAGENET_1k_LABELS[idx], 'value': val.item()} for val, idx in zip(values, indices)]
[{'label': 'beagle', 'value': 0.8486220836639404},
{'label': 'Walker_hound, Walker_foxhound', 'value': 0.03753996267914772},
{'label': 'basset, basset_hound', 'value': 0.024628572165966034},
{'label': 'bluetick', 'value': 0.010317106731235981},
{'label': 'English_foxhound', 'value': 0.006958036217838526}]
```

@ -0,0 +1,9 @@
# Data
[[autodoc]] timm.data.create_dataset
[[autodoc]] timm.data.create_loader
[[autodoc]] timm.data.create_transform
[[autodoc]] timm.data.resolve_data_config

@ -0,0 +1,5 @@
# Models
[[autodoc]] timm.create_model
[[autodoc]] timm.list_models

@ -0,0 +1,27 @@
# Optimization
This page contains the API reference documentation for learning rate optimizers included in `timm`.
## Optimizers
### Factory functions
[[autodoc]] timm.optim.optim_factory.create_optimizer
[[autodoc]] timm.optim.optim_factory.create_optimizer_v2
### Optimizer Classes
[[autodoc]] timm.optim.adabelief.AdaBelief
[[autodoc]] timm.optim.adafactor.Adafactor
[[autodoc]] timm.optim.adahessian.Adahessian
[[autodoc]] timm.optim.adamp.AdamP
[[autodoc]] timm.optim.adamw.AdamW
[[autodoc]] timm.optim.lamb.Lamb
[[autodoc]] timm.optim.lars.Lars
[[autodoc]] timm.optim.lookahead.Lookahead
[[autodoc]] timm.optim.madgrad.MADGRAD
[[autodoc]] timm.optim.nadam.Nadam
[[autodoc]] timm.optim.nvnovograd.NvNovoGrad
[[autodoc]] timm.optim.radam.RAdam
[[autodoc]] timm.optim.rmsprop_tf.RMSpropTF
[[autodoc]] timm.optim.sgdp.SGDP

@ -0,0 +1,19 @@
# Learning Rate Schedulers
This page contains the API reference documentation for learning rate schedulers included in `timm`.
## Schedulers
### Factory functions
[[autodoc]] timm.scheduler.scheduler_factory.create_scheduler
[[autodoc]] timm.scheduler.scheduler_factory.create_scheduler_v2
### Scheduler Classes
[[autodoc]] timm.scheduler.cosine_lr.CosineLRScheduler
[[autodoc]] timm.scheduler.multistep_lr.MultiStepLRScheduler
[[autodoc]] timm.scheduler.plateau_lr.PlateauLRScheduler
[[autodoc]] timm.scheduler.poly_lr.PolyLRScheduler
[[autodoc]] timm.scheduler.step_lr.StepLRScheduler
[[autodoc]] timm.scheduler.tanh_lr.TanhLRScheduler

@ -1,35 +0,0 @@
# Scripts
A train, validation, inference, and checkpoint cleaning script included in the github root folder. Scripts are not currently packaged in the pip release.
The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA specific performance enhancements based on
[NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples).
## Training Script
The variety of training args is large and not all combinations of options (or even options) have been fully tested. For the training dataset folder, specify the folder to the base that contains a `train` and `validation` folder.
To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
```bash
./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
```
<Tip>
It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. --amp defaults to native AMP as of timm ver 0.4.3. --apex-amp will force use of APEX components if they are installed.
</Tip>
## Validation / Inference Scripts
Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base as in training script.
To validate with the model's pretrained weights (if they exist):
```bash
python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained
```
To run inference from a checkpoint:
```bash
python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
```

@ -1,6 +1,44 @@
# Training Examples
# Scripts
## EfficientNet-B2 with RandAugment - 80.4 top-1, 95.1 top-5
A train, validation, inference, and checkpoint cleaning script included in the github root folder. Scripts are not currently packaged in the pip release.
The training and validation scripts evolved from early versions of the [PyTorch Imagenet Examples](https://github.com/pytorch/examples). I have added significant functionality over time, including CUDA specific performance enhancements based on
[NVIDIA's APEX Examples](https://github.com/NVIDIA/apex/tree/master/examples).
## Training Script
The variety of training args is large and not all combinations of options (or even options) have been fully tested. For the training dataset folder, specify the folder to the base that contains a `train` and `validation` folder.
To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process per GPU w/ cosine schedule, random-erasing prob of 50% and per-pixel random value:
```bash
./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4
```
<Tip>
It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. --amp defaults to native AMP as of timm ver 0.4.3. --apex-amp will force use of APEX components if they are installed.
</Tip>
## Validation / Inference Scripts
Validation and inference scripts are similar in usage. One outputs metrics on a validation set and the other outputs topk class ids in a csv. Specify the folder containing validation images, not the base as in training script.
To validate with the model's pretrained weights (if they exist):
```bash
python validate.py /imagenet/validation/ --model seresnext26_32x4d --pretrained
```
To run inference from a checkpoint:
```bash
python inference.py /imagenet/validation/ --model mobilenetv3_large_100 --checkpoint ./output/train/model_best.pth.tar
```
## Training Examples
### EfficientNet-B2 with RandAugment - 80.4 top-1, 95.1 top-5
These params are for dual Titan RTX cards with NVIDIA Apex installed:
@ -8,7 +46,7 @@ These params are for dual Titan RTX cards with NVIDIA Apex installed:
./distributed_train.sh 2 /imagenet/ --model efficientnet_b2 -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .016
```
## MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5
### MixNet-XL with RandAugment - 80.5 top-1, 94.9 top-5
This params are for dual Titan RTX cards with NVIDIA Apex installed:
@ -16,45 +54,45 @@ This params are for dual Titan RTX cards with NVIDIA Apex installed:
./distributed_train.sh 2 /imagenet/ --model mixnet_xl -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .969 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.3 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.3 --amp --lr .016 --dist-bn reduce
```
## SE-ResNeXt-26-D and SE-ResNeXt-26-T
### SE-ResNeXt-26-D and SE-ResNeXt-26-T
These hparams (or similar) work well for a wide range of ResNet architecture, generally a good idea to increase the epoch # as the model size increases... ie approx 180-200 for ResNe(X)t50, and 220+ for larger. Increase batch size and LR proportionally for better GPUs or with AMP enabled. These params were for 2 1080Ti cards:
```bash
./distributed_train.sh 2 /imagenet/ --model seresnext26t_32x4d --lr 0.1 --warmup-epochs 5 --epochs 160 --weight-decay 1e-4 --sched cosine --reprob 0.4 --remode pixel -b 112
```
## EfficientNet-B3 with RandAugment - 81.5 top-1, 95.7 top-5
### EfficientNet-B3 with RandAugment - 81.5 top-1, 95.7 top-5
The training of this model started with the same command line as EfficientNet-B2 w/ RA above. After almost three weeks of training the process crashed. The results weren't looking amazing so I resumed the training several times with tweaks to a few params (increase RE prob, decrease rand-aug, increase ema-decay). Nothing looked great. I ended up averaging the best checkpoints from all restarts. The result is mediocre at default res/crop but oddly performs much better with a full image test crop of 1.0.
## EfficientNet-B0 with RandAugment - 77.7 top-1, 95.3 top-5
### EfficientNet-B0 with RandAugment - 77.7 top-1, 95.3 top-5
[Michael Klachko](https://github.com/michaelklachko) achieved these results with the command line for B2 adapted for larger batch size, with the recommended B0 dropout rate of 0.2.
```bash
./distributed_train.sh 2 /imagenet/ --model efficientnet_b0 -b 384 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .048
```
## ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5
### ResNet50 with JSD loss and RandAugment (clean + 2x RA augs) - 79.04 top-1, 94.39 top-5
Trained on two older 1080Ti cards, this took a while. Only slightly, non statistically better ImageNet validation result than my first good AugMix training of 78.99. However, these weights are more robust on tests with ImageNetV2, ImageNet-Sketch, etc. Unlike my first AugMix runs, I've enabled SplitBatchNorm, disabled random erasing on the clean split, and cranked up random erasing prob on the 2 augmented paths.
```bash
./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce
```
## EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5
### EfficientNet-ES (EdgeTPU-Small) with RandAugment - 78.066 top-1, 93.926 top-5
Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model EMA was not used, final checkpoint is the average of 8 best checkpoints during training.
```bash
./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064
```
## MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
```bash
./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-path 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9
```
## ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
### ResNeXt-50 32x4d w/ RandAugment - 79.762 top-1, 94.60 top-5
These params will also work well for SE-ResNeXt-50 and SK-ResNeXt-50 and likely 101. I used them for the SK-ResNeXt-50 32x4d that I trained with 2 GPU using a slightly higher LR per effective batch size (lr=0.18, b=192 per GPU). The cmd line below are tuned for 8 GPU training.

@ -20,7 +20,7 @@ import torch
from timm.data import create_dataset, create_loader, resolve_data_config
from timm.layers import apply_test_time_pool
from timm.models import create_model
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser
from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
try:
from apex import amp
@ -72,6 +72,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -110,6 +112,7 @@ parser.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -170,12 +173,19 @@ def main():
set_jit_fuser(args.fuser)
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
pretrained=args.pretrained,
checkpoint_path=args.checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'

@ -44,3 +44,11 @@ markdown_extensions:
plugins:
- search
- awesome-pages
- redirects:
redirect_maps:
'index.md': 'https://huggingface.co/docs/timm/index'
'models.md': 'https://huggingface.co/docs/timm/models'
'results.md': 'https://huggingface.co/docs/timm/results'
'scripts.md': 'https://huggingface.co/docs/timm/training_script'
'training_hparam_examples.md': 'https://huggingface.co/docs/timm/training_script#training-examples'
'feature_extraction.md': 'https://huggingface.co/docs/timm/feature_extraction'

@ -1,4 +1,5 @@
mkdocs
mkdocs-material
mkdocs-redirects
mdx_truly_sane_lists
mkdocs-awesome-pages-plugin
mkdocs-awesome-pages-plugin

@ -38,7 +38,7 @@ An ImageNet test set of 10,000 images sampled from new images roughly 10 years a
### ImageNet-Adversarial - [`results-imagenet-a.csv`](results-imagenet-a.csv)
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occuring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
A collection of 7500 images covering 200 of the 1000 ImageNet classes. Images are naturally occurring adversarial examples that confuse typical ImageNet classifiers. This is a challenging dataset, your typical ResNet-50 will score 0% top-1.
For clean validation with same 200 classes, see [`results-imagenet-a-clean.csv`](results-imagenet-a-clean.csv)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

@ -27,9 +27,7 @@ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit*',
'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*', 'pvt_v2*', 'mvitv2*', 'gcvit*', 'efficientformer*',
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*'
'coatnet*', 'coatnext*', 'maxvit*', 'maxxvit*', 'eva_*', 'flexivit*'
]
NUM_NON_STD = len(NON_STD_FILTERS)
@ -40,7 +38,7 @@ if 'GITHUB_ACTIONS' in os.environ:
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
'swin*giant*', 'davit*giant', 'davit*huge']
'swin*giant*', 'davit_giant', 'davit_huge', 'convnextv2_huge*']
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'swin*giant*', 'eva_giant*']
else:
EXCLUDE_FILTERS = []
@ -131,7 +129,7 @@ def test_model_backward(model_name, batch_size):
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
"""Run a single forward pass with each model"""
@ -193,7 +191,7 @@ def test_model_default_cfgs(model_name, batch_size):
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(filter=NON_STD_FILTERS, exclude_filters=NON_STD_EXCLUDE_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs_non_std(model_name, batch_size):
"""Run a single forward pass with each model"""
@ -306,7 +304,7 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_FEAT_FILTERS, include_tags=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_features(model_name, batch_size):
"""Run a single forward pass with each model in feature extraction mode"""

@ -741,7 +741,6 @@ class RandAugment:
self.ops = ops
self.num_layers = num_layers
self.choice_weights = choice_weights
print(self.ops, self.choice_weights)
def __call__(self, img):
# no replacement when using weighted choice

@ -151,7 +151,7 @@ def create_dataset(
elif name.startswith('hfds/'):
# NOTE right now, HF datasets default arrow format is a random-access Dataset,
# There will be a IterableDataset variant too, TBD
ds = ImageDataset(root, reader=name, split=split, **kwargs)
ds = ImageDataset(root, reader=name, split=split, class_map=class_map, **kwargs)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
root,

@ -6,7 +6,7 @@ from .reader_image_in_tar import ReaderImageInTar
def create_reader(name, root, split='train', **kwargs):
name = name.lower()
name = name.split('/', 2)
name = name.split('/', 1)
prefix = ''
if len(name) > 1:
prefix = name[0]

@ -13,13 +13,14 @@ try:
except ImportError as e:
print("Please install Hugging Face datasets package `pip install datasets`.")
exit(1)
from .class_map import load_class_map
from .reader import Reader
def get_class_labels(info):
def get_class_labels(info, label_key='label'):
if 'label' not in info.features:
return {}
class_label = info.features['label']
class_label = info.features[label_key]
class_to_idx = {n: class_label.str2int(n) for n in class_label.names}
return class_to_idx
@ -32,6 +33,7 @@ class ReaderHfds(Reader):
name,
split='train',
class_map=None,
label_key='label',
download=False,
):
"""
@ -43,12 +45,17 @@ class ReaderHfds(Reader):
name, # 'name' maps to path arg in hf datasets
split=split,
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
#use_auth_token=True,
)
# leave decode for caller, plus we want easy access to original path names...
self.dataset = self.dataset.cast_column('image', datasets.Image(decode=False))
self.class_to_idx = get_class_labels(self.dataset.info)
self.label_key = label_key
self.remap_class = False
if class_map:
self.class_to_idx = load_class_map(class_map)
self.remap_class = True
else:
self.class_to_idx = get_class_labels(self.dataset.info, self.label_key)
self.split_info = self.dataset.info.splits[split]
self.num_samples = self.split_info.num_examples
@ -60,7 +67,10 @@ class ReaderHfds(Reader):
else:
assert 'path' in image and image['path']
image = open(image['path'], 'rb')
return image, item['label']
label = item[self.label_key]
if self.remap_class:
label = self.class_to_idx[label]
return image, label
def __len__(self):
return len(self.dataset)

@ -1,6 +1,7 @@
from .activations import *
from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .cond_conv2d import CondConv2d, get_condconv_initializer
@ -25,13 +26,18 @@ from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible,
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .patch_embed import PatchEmbed, resample_patch_embed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .pos_embed import resample_abs_pos_embed
from .pos_embed_rel import RelPosMlp, RelPosBias, RelPosBiasTf, gen_relative_position_index, gen_relative_log_coords
from .pos_embed_sincos import build_sincos2d_pos_embed, build_fourier_pos_embed, build_rotary_pos_embed, \
FourierEmbed, RotaryEmbedding
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvNormAct

@ -13,7 +13,7 @@ import torch
import torch.nn as nn
from .helpers import to_2tuple
from .pos_embed import apply_rot_embed, RotaryEmbedding
from .pos_embed_sincos import apply_rot_embed, RotaryEmbedding
from .weight_init import trunc_normal_

@ -0,0 +1,39 @@
""" Global Response Normalization Module
Based on the GRN layer presented in
`ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
This implementation
* works for both NCHW and NHWC tensor layouts
* uses affine param names matching existing torch norm layers
* slightly improves eager mode performance via fused addcmul
Hacked together by / Copyright 2023 Ross Wightman
"""
import torch
from torch import nn as nn
class GlobalResponseNorm(nn.Module):
""" Global Response Normalization layer
"""
def __init__(self, dim, eps=1e-6, channels_last=True):
super().__init__()
self.eps = eps
if channels_last:
self.spatial_dim = (1, 2)
self.channel_dim = -1
self.wb_shape = (1, 1, 1, -1)
else:
self.spatial_dim = (2, 3)
self.channel_dim = 1
self.wb_shape = (1, -1, 1, 1)
self.weight = nn.Parameter(torch.zeros(dim))
self.bias = nn.Parameter(torch.zeros(dim))
def forward(self, x):
x_g = x.norm(p=2, dim=self.spatial_dim, keepdim=True)
x_n = x_g / (x_g.mean(dim=self.channel_dim, keepdim=True) + self.eps)
return x + torch.addcmul(self.bias.view(self.wb_shape), self.weight.view(self.wb_shape), x * x_n)

@ -10,7 +10,7 @@ import collections.abc
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return x
return tuple(x)
return tuple(repeat(x, n))
return parse

@ -2,25 +2,38 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
from functools import partial
from torch import nn as nn
from .grn import GlobalResponseNorm
from .helpers import to_2tuple
class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
@ -36,18 +49,29 @@ class GluMlp(nn.Module):
""" MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.Sigmoid,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
assert hidden_features % 2 == 0
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.chunk_dim = 1 if use_conv else -1
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def init_weights(self):
@ -58,7 +82,7 @@ class GluMlp(nn.Module):
def forward(self, x):
x = self.fc1(x)
x, gates = x.chunk(2, dim=-1)
x, gates = x.chunk(2, dim=self.chunk_dim)
x = x * self.act(gates)
x = self.drop1(x)
x = self.fc2(x)
@ -70,8 +94,15 @@ class GatedMlp(nn.Module):
""" MLP as used in gMLP
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
gate_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -104,8 +135,15 @@ class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
norm_layer=None, bias=True, drop=0.):
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.ReLU,
norm_layer=None,
bias=True,
drop=0.,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@ -124,3 +162,40 @@ class ConvMlp(nn.Module):
x = self.drop(x)
x = self.fc2(x)
return x
class GlobalResponseNormMlp(nn.Module):
""" MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d
"""
def __init__(
self,
in_features,
hidden_features=None,
out_features=None,
act_layer=nn.GELU,
bias=True,
drop=0.,
use_conv=False,
):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0])
self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv)
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop1(x)
x = self.grn(x)
x = self.fc2(x)
x = self.drop2(x)
return x

@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None):
return module_output
class FrozenBatchNormAct2d(torch.nn.Module):
"""
BatchNormAct2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super().__init__()
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
x = x * scale + bias
x = self.act(self.drop(x))
return x
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
res = FrozenBatchNormAct2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNormAct2d):
res = BatchNormAct2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size):
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(
self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
num_groups=32,
eps=1e-5,
affine=True,
group_size=None,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNormAct, self).__init__(
_num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
_num_groups(num_channels, num_groups, group_size),
num_channels,
eps=eps,
affine=affine,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class GroupNorm1Act(nn.GroupNorm):
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm):
class LayerNormAct(nn.LayerNorm):
def __init__(
self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
normalization_shape: Union[int, List[int], torch.Size],
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm):
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module

@ -2,15 +2,24 @@
A convolution based approach to patchifying a 2D image w/ embedding projection.
Based on the impl in https://github.com/google-research/vision_transformer
Based on code in:
* https://github.com/google-research/vision_transformer
* https://github.com/google-research/big_vision/tree/main/big_vision
Hacked together by / Copyright 2020 Ross Wightman
"""
import logging
from typing import List
import torch
from torch import nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .trace_utils import _assert
_logger = logging.getLogger(__name__)
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
@ -46,3 +55,130 @@ class PatchEmbed(nn.Module):
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x
def resample_patch_embed(
patch_embed,
new_size: List[int],
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
"""Resample the weights of the patch embedding kernel to target resolution.
We resample the patch embedding kernel by approximately inverting the effect
of patch resizing.
Code based on:
https://github.com/google-research/big_vision/blob/b00544b81f8694488d5f36295aeb7972f3755ffe/big_vision/models/proj/flexi/vit.py
With this resizing, we can for example load a B/8 filter into a B/16 model
and, on 2x larger input image, the result will match.
Args:
patch_embed: original parameter to be resized.
new_size (tuple(int, int): target shape (height, width)-only.
interpolation (str): interpolation for resize
antialias (bool): use anti-aliasing filter in resize
verbose (bool): log operation
Returns:
Resized patch embedding kernel.
"""
import numpy as np
try:
import functorch
vmap = functorch.vmap
except ImportError:
if hasattr(torch, 'vmap'):
vmap = torch.vmap
else:
assert False, "functorch or a version of torch with vmap is required for FlexiViT resizing."
assert len(patch_embed.shape) == 4, "Four dimensions expected"
assert len(new_size) == 2, "New shape should only be hw"
old_size = patch_embed.shape[-2:]
if tuple(old_size) == tuple(new_size):
return patch_embed
if verbose:
_logger.info(f"Resize patch embedding {patch_embed.shape} to {new_size}, w/ {interpolation} interpolation.")
def resize(x_np, _new_size):
x_tf = torch.Tensor(x_np)[None, None, ...]
x_upsampled = F.interpolate(
x_tf, size=_new_size, mode=interpolation, antialias=antialias)[0, 0, ...].numpy()
return x_upsampled
def get_resize_mat(_old_size, _new_size):
mat = []
for i in range(np.prod(_old_size)):
basis_vec = np.zeros(_old_size)
basis_vec[np.unravel_index(i, _old_size)] = 1.
mat.append(resize(basis_vec, _new_size).reshape(-1))
return np.stack(mat).T
resize_mat = get_resize_mat(old_size, new_size)
resize_mat_pinv = torch.Tensor(np.linalg.pinv(resize_mat.T))
def resample_kernel(kernel):
resampled_kernel = resize_mat_pinv @ kernel.reshape(-1)
return resampled_kernel.reshape(new_size)
v_resample_kernel = vmap(vmap(resample_kernel, 0, 0), 1, 1)
return v_resample_kernel(patch_embed)
# def divs(n, m=None):
# m = m or n // 2
# if m == 1:
# return [1]
# if n % m == 0:
# return [m] + divs(n, m - 1)
# return divs(n, m - 1)
#
#
# class FlexiPatchEmbed(nn.Module):
# """ 2D Image to Patch Embedding w/ Flexible Patch sizes (FlexiViT)
# FIXME WIP
# """
# def __init__(
# self,
# img_size=240,
# patch_size=16,
# in_chans=3,
# embed_dim=768,
# base_img_size=240,
# base_patch_size=32,
# norm_layer=None,
# flatten=True,
# bias=True,
# ):
# super().__init__()
# self.img_size = to_2tuple(img_size)
# self.patch_size = to_2tuple(patch_size)
# self.num_patches = 0
#
# # full range for 240 = (5, 6, 8, 10, 12, 14, 15, 16, 20, 24, 30, 40, 48)
# self.seqhw = (6, 8, 10, 12, 14, 15, 16, 20, 24, 30)
#
# self.base_img_size = to_2tuple(base_img_size)
# self.base_patch_size = to_2tuple(base_patch_size)
# self.base_grid_size = tuple([i // p for i, p in zip(self.base_img_size, self.base_patch_size)])
# self.base_num_patches = self.base_grid_size[0] * self.base_grid_size[1]
#
# self.flatten = flatten
# self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=bias)
# self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#
# def forward(self, x):
# B, C, H, W = x.shape
#
# if self.patch_size == self.base_patch_size:
# weight = self.proj.weight
# else:
# weight = resample_patch_embed(self.proj.weight, self.patch_size)
# patch_size = self.patch_size
# x = F.conv2d(x, weight, bias=self.proj.bias, stride=patch_size)
# if self.flatten:
# x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
# x = self.norm(x)
# return x

@ -1,207 +1,52 @@
""" Position Embedding Utilities
Hacked together by / Copyright 2022 Ross Wightman
"""
import logging
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def inv_freq_bands(
num_bands: int,
temperature: float = 100000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
import torch.nn.functional as F
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
from .helpers import to_2tuple
_logger = logging.getLogger(__name__)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
linear_bands: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
def resample_abs_pos_embed(
posemb,
new_size: List[int],
old_size: Optional[List[int]] = None,
num_prefix_tokens: int = 1,
interpolation: str = 'bicubic',
antialias: bool = True,
verbose: bool = False,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape, bands=bands, num_bands=dim // 4, max_res=max_freq, linear_bands=linear_bands,
concat_out=False, device=device, dtype=dtype)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
# sort out sizes, assume square if old size not provided
new_size = to_2tuple(new_size)
new_ntok = new_size[0] * new_size[1]
if not old_size:
old_size = int(math.sqrt(posemb.shape[1] - num_prefix_tokens))
old_size = to_2tuple(old_size)
if new_size == old_size: # might not both be same container type
return posemb
if num_prefix_tokens:
posemb_prefix, posemb = posemb[:, :num_prefix_tokens], posemb[:, num_prefix_tokens:]
else:
posemb_prefix, posemb = None, posemb
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
# do the interpolation
posemb = posemb.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
posemb = F.interpolate(posemb, size=new_size, mode=interpolation, antialias=antialias)
posemb = posemb.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
if verbose:
_logger.info(f'Resized position embedding: {old_size} to {new_size}.')
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
# add back extra (class, etc) prefix tokens
if posemb_prefix is not None:
print(posemb_prefix.shape, posemb.shape)
posemb = torch.cat([posemb_prefix, posemb], dim=1)
return posemb

@ -0,0 +1,283 @@
""" Relative position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from .mlp import Mlp
from .weight_init import trunc_normal_
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index.contiguous()
class RelPosBias(nn.Module):
""" Relative Position Bias
Adapted from Swin-V1 relative position bias impl, modularized.
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
class RelPosMlp(nn.Module):
""" Log-Coordinate Relative Position MLP
Based on ideas presented in Swin-V2 paper (https://arxiv.org/abs/2111.09883)
This impl covers the 'swin' implementation as well as two timm specific modes ('cr', and 'rw')
"""
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
):
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
Returns:
a lookup Tensor of size [length, length, vocab_size] that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
# Return the cached lookup tensor, otherwise compute it and cache it.
vocab_size = 2 * max_relative_position + 1
ret = torch.zeros(length, length, vocab_size)
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
continue
ret[i, x, v] = 1
return ret
def reindex_2d_einsum_lookup(
relative_position_tensor,
height: int,
width: int,
height_lookup: torch.Tensor,
width_lookup: torch.Tensor,
) -> torch.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
height_lookup: one-hot height lookup
width_lookup: one-hot width lookup
Returns:
reindexed_tensor: a Tensor of shape
[..., height * width, height * width, ...]
"""
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
area = height * width
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
class RelPosBiasTf(nn.Module):
""" Relative Position Bias Impl (Compatible with Tensorflow MaxViT models)
Adapted from:
https://github.com/google-research/maxvit/blob/2e06a7f1f70c76e64cd3dabe5cd1b8c1a23c9fb7/maxvit/models/attention_utils.py
"""
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.num_heads = num_heads
vocab_height = 2 * window_size[0] - 1
vocab_width = 2 * window_size[1] - 1
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
# FIXME change to not use one-hot/einsum?
return reindex_2d_einsum_lookup(
self.relative_position_bias_table,
self.window_size[0],
self.window_size[1],
self.height_lookup,
self.width_lookup
)
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()

@ -0,0 +1,219 @@
""" Sin-cos, fourier, rotary position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def inv_freq_bands(
num_bands: int,
temperature: float = 100000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
inv_freq = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return inv_freq
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = inv_freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(
torch.meshgrid([torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
linear_bands: bool = False,
include_grid: bool = False,
concat_out: bool = True,
in_pixels: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
if bands is None:
if in_pixels:
bands = pixel_freq_bands(num_bands, float(max_res), linear_bands=linear_bands, dtype=dtype, device=device)
else:
bands = inv_freq_bands(num_bands, step=1, dtype=dtype, device=device)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
grid = torch.stack(torch.meshgrid(
[torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
else:
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape]), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = (grid, pos_sin, pos_cos) if include_grid else (pos_sin, pos_cos)
# FIXME torchscript doesn't like multiple return types, probably need to always cat?
if concat_out:
out = torch.cat(out, dim=-1)
return out
class FourierEmbed(nn.Module):
def __init__(self, max_res: int = 224, num_bands: int = 64, concat_grid=True, keep_spatial=False):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer('bands', pixel_freq_bands(max_res, num_bands), persistent=False)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_split(x: torch.Tensor, emb):
split = emb.shape[-1] // 2
return x * emb[:, :split] + rot(x) * emb[:, split:]
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_freq: float = 224,
linear_bands: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
NOTE: shape arg should include spatial dim only
"""
feat_shape = torch.Size(feat_shape)
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 4,
max_res=max_freq,
linear_bands=linear_bands,
concat_out=False,
device=device,
dtype=dtype,
)
N = feat_shape.numel()
sin_emb = sin_emb.reshape(N, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(N, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(self, dim, max_res=224, linear_bands: bool = False):
super().__init__()
self.dim = dim
self.register_buffer('bands', pixel_freq_bands(dim // 4, max_res, linear_bands=linear_bands), persistent=False)
def get_embed(self, shape: List[int]):
return build_rotary_pos_embed(shape, self.bands)
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)

@ -1,5 +1,6 @@
import dataclasses
import logging
import os
from copy import deepcopy
from typing import Optional, Dict, Callable, Any, Tuple
@ -9,7 +10,7 @@ from torch.hub import load_state_dict_from_url
from timm.models._features import FeatureListNet, FeatureHookNet
from timm.models._features_fx import FeatureGraphNet
from timm.models._helpers import load_state_dict
from timm.models._hub import has_hf_hub, download_cached_file, load_state_dict_from_hf
from timm.models._hub import has_hf_hub, download_cached_file, check_cached_file, load_state_dict_from_hf
from timm.models._manipulate import adapt_input_conv
from timm.models._pretrained import PretrainedCfg
from timm.models._prune import adapt_model_from_file
@ -32,6 +33,7 @@ def _resolve_pretrained_source(pretrained_cfg):
pretrained_url = pretrained_cfg.get('url', None)
pretrained_file = pretrained_cfg.get('file', None)
hf_hub_id = pretrained_cfg.get('hf_hub_id', None)
# resolve where to load pretrained weights from
load_from = ''
pretrained_loc = ''
@ -43,15 +45,20 @@ def _resolve_pretrained_source(pretrained_cfg):
else:
# default source == timm or unspecified
if pretrained_file:
# file load override is the highest priority if set
load_from = 'file'
pretrained_loc = pretrained_file
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
elif hf_hub_id and has_hf_hub(necessary=True):
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
else:
# next, HF hub is prioritized unless a valid cached version of weights exists already
cached_url_valid = check_cached_file(pretrained_url) if pretrained_url else False
if hf_hub_id and has_hf_hub(necessary=True) and not cached_url_valid:
# hf-hub available as alternate weight source in default_cfg
load_from = 'hf-hub'
pretrained_loc = hf_hub_id
elif pretrained_url:
load_from = 'url'
pretrained_loc = pretrained_url
if load_from == 'hf-hub' and pretrained_cfg.get('hf_hub_filename', None):
# if a filename override is set, return tuple for location w/ (hub_id, filename)
pretrained_loc = pretrained_loc, pretrained_cfg['hf_hub_filename']
@ -105,7 +112,7 @@ def load_custom_pretrained(
pretrained_loc = download_cached_file(
pretrained_loc,
check_hash=_CHECK_HASH,
progress=_DOWNLOAD_PROGRESS
progress=_DOWNLOAD_PROGRESS,
)
if load_fn is not None:
@ -146,12 +153,21 @@ def load_pretrained(
state_dict = load_state_dict(pretrained_loc)
elif load_from == 'url':
_logger.info(f'Loading pretrained weights from url ({pretrained_loc})')
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
if pretrained_cfg.get('custom_load', False):
pretrained_loc = download_cached_file(
pretrained_loc,
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
model.load_pretrained(pretrained_loc)
return
else:
state_dict = load_state_dict_from_url(
pretrained_loc,
map_location='cpu',
progress=_DOWNLOAD_PROGRESS,
check_hash=_CHECK_HASH,
)
elif load_from == 'hf-hub':
_logger.info(f'Loading pretrained weights from Hugging Face hub ({pretrained_loc})')
if isinstance(pretrained_loc, (list, tuple)):
@ -364,20 +380,14 @@ def build_model_with_cfg(
# For classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained:
if pretrained_cfg.get('custom_load', False):
load_custom_pretrained(
model,
pretrained_cfg=pretrained_cfg,
)
else:
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
load_pretrained(
model,
pretrained_cfg=pretrained_cfg,
num_classes=num_classes_pretrained,
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn,
strict=pretrained_strict,
)
# Wrap the model in a feature extraction module if enabled
if features:

@ -1,3 +1,4 @@
import hashlib
import json
import logging
import os
@ -67,6 +68,26 @@ def download_cached_file(url, check_hash=True, progress=False):
return cached_file
def check_cached_file(url, check_hash=True):
if isinstance(url, (list, tuple)):
url, filename = url
else:
parts = urlparse(url)
filename = os.path.basename(parts.path)
cached_file = os.path.join(get_cache_dir(), filename)
if os.path.exists(cached_file):
if check_hash:
r = HASH_REGEX.search(filename) # r is Optional[Match[str]]
hash_prefix = r.group(1) if r else None
if hash_prefix:
with open(cached_file, 'rb') as f:
hd = hashlib.sha256(f.read()).hexdigest()
if hd[:len(hash_prefix)] != hash_prefix:
return False
return True
return False
def has_hf_hub(necessary=False):
if not _has_hf_hub and necessary:
# if no HF Hub module installed, and it is necessary to continue, raise error
@ -90,14 +111,14 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
return json.loads(text)
def _download_from_hf(model_id: str, filename: str):
def download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id)
return hf_hub_download(hf_model_id, filename, revision=hf_revision)
def load_model_config_from_hf(model_id: str):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, 'config.json')
cached_file = download_from_hf(model_id, 'config.json')
hf_config = load_cfg_from_json(cached_file)
if 'pretrained_cfg' not in hf_config:
@ -124,34 +145,28 @@ def load_model_config_from_hf(model_id: str):
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
assert has_hf_hub(True)
cached_file = _download_from_hf(model_id, filename)
cached_file = download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
def save_config_for_hf(model, config_path, model_config=None):
model_config = model_config or {}
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
# set some values at root config level
hf_config['architecture'] = pretrained_cfg.pop('architecture')
hf_config['num_classes'] = model_config.get('num_classes', model.num_classes)
hf_config['num_features'] = model_config.get('num_features', model.num_features)
hf_config['global_pool'] = model_config.get('global_pool', getattr(model, 'global_pool', None))
global_pool_type = model_config.get('global_pool', getattr(model, 'global_pool', None))
if isinstance(global_pool_type, str) and global_pool_type:
hf_config['global_pool'] = global_pool_type
if 'label' in model_config:
if 'labels' in model_config:
_logger.warning(
"'label' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"'labels' as a config field for timm models is deprecated. Please use 'label_name' and 'display_name'. "
"Using provided 'label' field as 'label_name'.")
model_config['label_name'] = model_config.pop('label')
model_config['label_name'] = model_config.pop('labels')
label_name = model_config.pop('label_name', None)
if label_name:
@ -173,6 +188,18 @@ def save_for_hf(model, save_directory, model_config=None):
json.dump(hf_config, f, indent=2)
def save_for_hf(model, save_directory, model_config=None):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
config_path = save_directory / 'config.json'
save_config_for_hf(model, config_path, model_config=model_config)
def push_to_hf_hub(
model,
repo_id: str,
@ -182,6 +209,7 @@ def push_to_hf_hub(
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
):
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@ -205,9 +233,23 @@ def push_to_hf_hub(
# Add readme if it does not exist
if not has_readme:
model_card = model_card or {}
model_name = repo_id.split('/')[-1]
readme_path = Path(tmpdir) / "README.md"
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {model_name}'
readme_text = "---\n"
readme_text += "tags:\n- image-classification\n- timm\n"
readme_text += "library_tag: timm\n"
readme_text += f"license: {model_card.get('license', 'apache-2.0')}\n"
readme_text += "---\n"
readme_text += f"# Model card for {model_name}\n"
if 'description' in model_card:
readme_text += f"\n{model_card['description']}\n"
if 'details' in model_card:
readme_text += f"\n## Model Details\n"
for k, v in model_card['details'].items():
readme_text += f"- **{k}:** {v}\n"
if 'citation' in model_card:
readme_text += f"\n## Citation\n```\n{model_card['citation']}```\n"
readme_path.write_text(readme_text)
# Upload model and return

@ -19,6 +19,7 @@ class PretrainedCfg:
source: Optional[str] = None # source of cfg / weight location used (url, file, hf-hub)
architecture: Optional[str] = None # architecture variant can be set when not implicit
tag: Optional[str] = None # pretrained tag of source
custom_load: bool = False # use custom model specific model.load_pretrained() (ie for npz files)
# input / data config
@ -44,9 +45,11 @@ class PretrainedCfg:
classifier: Optional[str] = None
license: Optional[str] = None
source_url: Optional[str] = None
paper: Optional[str] = None
notes: Optional[str] = None
description: Optional[str] = None
origin_url: Optional[str] = None
paper_name: Optional[str] = None
paper_ids: Optional[Union[str, Tuple[str]]] = None
notes: Optional[Tuple[str]] = None
@property
def has_weights(self):
@ -62,11 +65,11 @@ class PretrainedCfg:
def filter_pretrained_cfg(cfg, remove_source=False, remove_null=True):
filtered_cfg = {}
keep_none = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
keep_null = {'pool_size', 'first_conv', 'classifier'} # always keep these keys, even if none
for k, v in cfg.items():
if remove_source and k in {'url', 'file', 'hf_hub_id', 'hf_hub_id', 'hf_hub_filename', 'source'}:
continue
if remove_null and v is None and k not in keep_none:
if remove_null and v is None and k not in keep_null:
continue
filtered_cfg[k] = v
return filtered_cfg

@ -7,6 +7,7 @@ import re
import sys
from collections import defaultdict, deque
from copy import deepcopy
from dataclasses import replace
from typing import List, Optional, Union, Tuple
from ._pretrained import PretrainedCfg, DefaultCfg, split_model_name_tag
@ -20,7 +21,7 @@ _model_to_module = {} # mapping of model names to module names
_model_entrypoints = {} # mapping of model names to architecture entrypoint fns
_model_has_pretrained = set() # set of model names that have pretrained weight url present
_model_default_cfgs = dict() # central repo for model arch -> default cfg objects
_model_pretrained_cfgs = dict() # central repo for model arch + tag -> pretrained cfgs
_model_pretrained_cfgs = dict() # central repo for model arch.tag -> pretrained cfgs
_model_with_tags = defaultdict(list) # shortcut to map each model arch to all model + tag names
@ -48,24 +49,31 @@ def register_model(fn):
if hasattr(mod, 'default_cfgs') and model_name in mod.default_cfgs:
# this will catch all models that have entrypoint matching cfg key, but miss any aliasing
# entrypoints or non-matching combos
cfg = mod.default_cfgs[model_name]
if not isinstance(cfg, DefaultCfg):
default_cfg = mod.default_cfgs[model_name]
if not isinstance(default_cfg, DefaultCfg):
# new style default cfg dataclass w/ multiple entries per model-arch
assert isinstance(cfg, dict)
assert isinstance(default_cfg, dict)
# old style cfg dict per model-arch
cfg = PretrainedCfg(**cfg)
cfg = DefaultCfg(tags=deque(['']), cfgs={'': cfg})
pretrained_cfg = PretrainedCfg(**default_cfg)
default_cfg = DefaultCfg(tags=deque(['']), cfgs={'': pretrained_cfg})
for tag_idx, tag in enumerate(cfg.tags):
for tag_idx, tag in enumerate(default_cfg.tags):
is_default = tag_idx == 0
pretrained_cfg = cfg.cfgs[tag]
pretrained_cfg = default_cfg.cfgs[tag]
model_name_tag = '.'.join([model_name, tag]) if tag else model_name
replace_items = dict(architecture=model_name, tag=tag if tag else None)
if pretrained_cfg.hf_hub_id and pretrained_cfg.hf_hub_id == 'timm/':
# auto-complete hub name w/ architecture.tag
replace_items['hf_hub_id'] = pretrained_cfg.hf_hub_id + model_name_tag
pretrained_cfg = replace(pretrained_cfg, **replace_items)
if is_default:
_model_pretrained_cfgs[model_name] = pretrained_cfg
if pretrained_cfg.has_weights:
# add tagless entry if it's default and has weights
_model_has_pretrained.add(model_name)
if tag:
model_name_tag = '.'.join([model_name, tag])
_model_pretrained_cfgs[model_name_tag] = pretrained_cfg
if pretrained_cfg.has_weights:
# add model w/ tag if tag is valid
@ -74,7 +82,7 @@ def register_model(fn):
else:
_model_with_tags[model_name].append(model_name) # has empty tag (to slowly remove these instances)
_model_default_cfgs[model_name] = cfg
_model_default_cfgs[model_name] = default_cfg
return fn
@ -198,15 +206,21 @@ def is_model_pretrained(model_name):
return model_name in _model_has_pretrained
def get_pretrained_cfg(model_name):
def get_pretrained_cfg(model_name, allow_unregistered=True):
if model_name in _model_pretrained_cfgs:
return deepcopy(_model_pretrained_cfgs[model_name])
raise RuntimeError(f'No pretrained config exists for model {model_name}.')
arch_name, tag = split_model_name_tag(model_name)
if arch_name in _model_default_cfgs:
# if model arch exists, but the tag is wrong, error out
raise RuntimeError(f'Invalid pretrained tag ({tag}) for {arch_name}.')
if allow_unregistered:
# if model arch doesn't exist, it has no pretrained_cfg registered, allow a default to be created
return None
raise RuntimeError(f'Model architecture ({arch_name}) has no pretrained cfg registered.')
def get_pretrained_cfg_value(model_name, cfg_key):
""" Get a specific model default_cfg value by key. None if key doesn't exist.
"""
if model_name in _model_pretrained_cfgs:
return getattr(_model_pretrained_cfgs[model_name], cfg_key, None)
raise RuntimeError(f'No pretrained config exist for model {model_name}.')
cfg = get_pretrained_cfg(model_name, allow_unregistered=False)
return getattr(cfg, cfg_key, None)

@ -355,64 +355,76 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
'beit_base_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth',
hf_hub_id='timm/'),
'beit_base_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_base_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
hf_hub_id='timm/',
num_classes=21841,
),
'beit_large_patch16_224.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth',
hf_hub_id='timm/'),
'beit_large_patch16_384.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0,
),
'beit_large_patch16_512.in22k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
hf_hub_id='timm/',
input_size=(3, 512, 512), crop_pct=1.0,
),
'beit_large_patch16_224.in22k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
hf_hub_id='timm/',
num_classes=21841,
),
'beitv2_base_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21kto1k.pth',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_base_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_base_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k_in1k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21kto1k.pth',
crop_pct=0.95,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
crop_pct=0.95, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'beitv2_large_patch16_224.in1k_ft_in22k': _cfg(
url='https://conversationhub.blob.core.windows.net/beit-share-public/beitv2/beitv2_large_patch16_224_pt1k_ft21k.pth',
num_classes=21841,
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
hf_hub_id='timm/',
num_classes=21841, mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD
),
'eva_giant_patch14_224.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz224_ftcls_89p1.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, crop_pct=1.0,
),
'eva_giant_patch14_336.clip_ft_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_clip_vis_enc_sz336_ftcls_89p4.pt',
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_336.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_336px_psz14_ema_89p6.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 336, 336), crop_pct=1.0, crop_mode='squash'),
'eva_giant_patch14_560.m30m_ft_in22k_in1k': _cfg(
hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
# hf_hub_id='BAAI/EVA', hf_hub_filename='eva_21k_1k_560px_psz14_ema_89p7.pt',
hf_hub_id='timm/',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
input_size=(3, 560, 560), crop_pct=1.0, crop_mode='squash'),
})

@ -218,7 +218,10 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
def interleave_blocks(
types: Tuple[str, str], d, every: Union[int, List[int]] = 1, first: bool = False, **kwargs
types: Tuple[str, str], d,
every: Union[int, List[int]] = 1,
first: bool = False,
**kwargs,
) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack
"""
@ -962,9 +965,21 @@ class BasicBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
group_size=None,
bottle_ratio=1.0,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -983,7 +998,7 @@ class BasicBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1005,9 +1020,23 @@ class BottleneckBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, bottle_in=False,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
extra_conv=False,
bottle_in=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1031,7 +1060,7 @@ class BottleneckBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1063,9 +1092,21 @@ class DarkBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=True,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1085,7 +1126,7 @@ class DarkBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_kxk.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1114,9 +1155,21 @@ class EdgeBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='avg',
attn_last=False,
linear_out=False,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -1135,7 +1188,7 @@ class EdgeBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv2_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
@ -1162,8 +1215,19 @@ class RepVggBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.0,
group_size=None,
downsample='',
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -1204,9 +1268,24 @@ class SelfAttnBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, bottle_in=False, post_attn_na=True,
feat_size=None, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
kernel_size=3,
stride=1,
dilation=(1, 1),
bottle_ratio=1.,
group_size=None,
downsample='avg',
extra_conv=False,
linear_out=False,
bottle_in=False,
post_attn_na=True,
feat_size=None,
layers: LayerFn = None,
drop_block=None,
drop_path_rate=0.,
):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible((in_chs if bottle_in else out_chs) * bottle_ratio)
@ -1233,7 +1312,7 @@ class SelfAttnBlock(nn.Module):
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last: bool = False):
if zero_init_last and self.shortcut is not None:
if zero_init_last and self.shortcut is not None and getattr(self.conv3_1x1.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
@ -1274,8 +1353,17 @@ def create_block(block: Union[str, nn.Module], **kwargs):
class Stem(nn.Sequential):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=4, pool='maxpool',
num_rep=3, num_act=None, chs_decay=0.5, layers: LayerFn = None):
self,
in_chs,
out_chs,
kernel_size=3,
stride=4,
pool='maxpool',
num_rep=3,
num_act=None,
chs_decay=0.5,
layers: LayerFn = None,
):
super().__init__()
assert stride in (2, 4)
layers = layers or LayerFn()
@ -1319,7 +1407,14 @@ class Stem(nn.Sequential):
assert curr_stride == stride
def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
def create_byob_stem(
in_chs,
out_chs,
stem_type='',
pool_type='',
feat_prefix='stem',
layers: LayerFn = None,
):
layers = layers or LayerFn()
assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
if 'quad' in stem_type:
@ -1407,10 +1502,14 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
def create_byob_stages(
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
cfg: ByoModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
block_kwargs_fn: Optional[Callable] = update_block_kwargs,
):
layers = layers or LayerFn()
feature_info = []
@ -1485,12 +1584,38 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(
self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last=True, img_size=None, drop_rate=0., drop_path_rate=0.):
self,
cfg: ByoModelCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
img_size=None,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (ByoModelCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
img_size (Union[int, Tuple[int]): Image size for fixed image size models (i.e. self-attn)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'

@ -1,25 +1,51 @@
""" ConvNeXt
Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Original code and weights from https://github.com/facebookresearch/ConvNeXt, original copyright below
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm specific.
Papers:
* `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
@Article{liu2022convnet,
author = {Zhuang Liu and Hanzi Mao and Chao-Yuan Wu and Christoph Feichtenhofer and Trevor Darrell and Saining Xie},
title = {A ConvNet for the 2020s},
journal = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2022},
}
* `ConvNeXt-V2 - Co-designing and Scaling ConvNets with Masked Autoencoders` - https://arxiv.org/abs/2301.00808
@article{Woo2023ConvNeXtV2,
title={ConvNeXt V2: Co-designing and Scaling ConvNets with Masked Autoencoders},
author={Sanghyun Woo, Shoubhik Debnath, Ronghang Hu, Xinlei Chen, Zhuang Liu, In So Kweon and Saining Xie},
year={2023},
journal={arXiv preprint arXiv:2301.00808},
}
Original code and weights from:
* https://github.com/facebookresearch/ConvNeXt, original copyright below
* https://github.com/facebookresearch/ConvNeXt-V2, original copyright below
Model defs atto, femto, pico, nano and _ols / _hnf variants are timm originals.
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# ConvNeXt
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
# ConvNeXt-V2
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, LayerNorm, \
create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
@ -54,6 +80,7 @@ class ConvNeXtBlock(nn.Module):
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
use_grn=False,
ls_init_value=1e-6,
act_layer='gelu',
norm_layer=None,
@ -64,14 +91,13 @@ class ConvNeXtBlock(nn.Module):
act_layer = get_act_layer(act_layer)
if not norm_layer:
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
mlp_layer = ConvMlp if conv_mlp else Mlp
mlp_layer = partial(GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp)
self.use_conv_mlp = conv_mlp
self.conv_dw = create_conv2d(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation, depthwise=True, bias=conv_bias)
self.norm = norm_layer(out_chs)
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value > 0 else None
self.gamma = nn.Parameter(ls_init_value * torch.ones(out_chs)) if ls_init_value is not None else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
@ -106,6 +132,7 @@ class ConvNeXtStage(nn.Module):
ls_init_value=1.0,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_layer_cl=None
@ -138,8 +165,9 @@ class ConvNeXtStage(nn.Module):
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
@ -156,16 +184,6 @@ class ConvNeXtStage(nn.Module):
class ConvNeXt(nn.Module):
r""" ConvNeXt
A PyTorch impl of : `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf
Args:
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
drop_rate (float): Head dropout rate
drop_path_rate (float): Stochastic depth rate. Default: 0.
ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
"""
def __init__(
@ -184,21 +202,50 @@ class ConvNeXt(nn.Module):
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_eps=None,
drop_rate=0.,
drop_path_rate=0.,
):
"""
Args:
in_chans (int): Number of input image channels (default: 3)
num_classes (int): Number of classes for classification head (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
ls_init_value (float): Init value for Layer Scale (default: 1e-6)
stem_type (str): Type of stem (default: 'patch')
patch_size (int): Stem patch size for patch stem (default: 4)
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
head_norm_first (bool): Apply normalization before global pool + head (default: False)
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
conv_bias (bool): Use bias layers w/ all convolutions (default: True)
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
act_layer (Union[str, nn.Module]): Activation Layer
norm_layer (Union[str, nn.Module]): Normalization Layer
drop_rate (float): Head dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth rate (default: 0.)
"""
super().__init__()
assert output_stride in (8, 16, 32)
kernel_sizes = to_ntuple(4)(kernel_sizes)
if norm_layer is None:
norm_layer = LayerNorm2d
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
if norm_eps is not None:
norm_layer = partial(norm_layer, eps=norm_eps)
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
else:
assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
norm_layer_cl = norm_layer
if norm_eps is not None:
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
self.num_classes = num_classes
self.drop_rate = drop_rate
@ -209,7 +256,7 @@ class ConvNeXt(nn.Module):
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size, bias=conv_bias),
norm_layer(dims[0])
norm_layer(dims[0]),
)
stem_stride = patch_size
else:
@ -247,9 +294,10 @@ class ConvNeXt(nn.Module):
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
use_grn=use_grn,
act_layer=act_layer,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
norm_layer_cl=norm_layer_cl,
))
prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
@ -334,7 +382,15 @@ def checkpoint_filter_fn(state_dict, model):
return state_dict # non-FB checkpoint
if 'model' in state_dict:
state_dict = state_dict['model']
out_dict = {}
if 'visual.trunk.stem.0.weight' in state_dict:
out_dict = {k.replace('visual.trunk.', ''): v for k, v in state_dict.items() if k.startswith('visual.trunk.')}
if 'visual.head.proj.weight' in state_dict:
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
return out_dict
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
@ -342,6 +398,10 @@ def checkpoint_filter_fn(state_dict, model):
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
if 'grn' in k:
k = k.replace('grn.beta', 'mlp.grn.bias')
k = k.replace('grn.gamma', 'mlp.grn.weight')
v = v.reshape(v.shape[-1])
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
@ -349,10 +409,16 @@ def checkpoint_filter_fn(state_dict, model):
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_convnext(variant, pretrained=False, **kwargs):
if kwargs.get('pretrained_cfg', '') == 'fcmae':
# NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
# This is workaround loading with num_classes=0 w/o removing norm-layer.
kwargs.setdefault('pretrained_strict', False)
model = build_model_with_cfg(
ConvNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
@ -361,7 +427,6 @@ def _create_convnext(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
@ -373,92 +438,295 @@ def _cfg(url='', **kwargs):
}
def _cfgv2(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
'license': 'cc-by-nc-4.0', 'paper_ids': 'arXiv:2301.00808',
'paper_name': 'ConvNeXt-V2: Co-designing and Scaling ConvNets with Masked Autoencoders',
'origin_url': 'https://github.com/facebookresearch/ConvNeXt-V2',
**kwargs
}
default_cfgs = generate_default_cfgs({
# timm specific variants
'convnext_atto.timm_in1k': _cfg(
'convnext_atto.d2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_d2-01bb0f51.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_atto_ols.timm_in1k': _cfg(
'convnext_atto_ols.a2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_atto_ols_a2-78d1c8f3.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto.timm_in1k': _cfg(
'convnext_femto.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_d1-d71d5b4c.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_femto_ols.timm_in1k': _cfg(
'convnext_femto_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_femto_ols_d1-246bf2ed.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico.timm_in1k': _cfg(
'convnext_pico.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_d1-10ad7f0d.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnext_pico_ols.timm_in1k': _cfg(
'convnext_pico_ols.d1_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_pico_ols_d1-611f0ca7.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.timm_in1k': _cfg(
'convnext_nano.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_d1h-7eb4bdea.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano_ols.timm_in1k': _cfg(
'convnext_nano_ols.d1h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_nano_ols_d1h-ae424a9a.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny_hnf.timm_in1k': _cfg(
'convnext_tiny_hnf.a2h_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.in12k_ft_in1k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_nano.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_small.in12k': _cfg(
hf_hub_id='timm/',
crop_pct=0.95, num_classes=11821),
'convnext_tiny.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in1k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.untrained': _cfg(),
'convnext_xxlarge.untrained': _cfg(),
'convnext_tiny.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_small.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_base.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_large.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_224.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_xlarge.fb_in22k_ft_in1k': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_224_ema.pth',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_small..fb_in22k_ft_in1k_384': _cfg(
'convnext_small.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_base.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_large.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_1k_384.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_xlarge.fb_in22k_ft_in1k_384': _cfg(
url='https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_1k_384_ema.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnext_tiny.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_small.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_base.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_large.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnext_xlarge.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth",
hf_hub_id='timm/',
num_classes=21841),
'convnextv2_nano.fcmae_ft_in22k_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_nano.fcmae_ft_in22k_in1k_384': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_nano_22k_384_ema.pt',
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_tiny.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_tiny.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_tiny_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_base.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_base.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_large.fcmae_ft_in22k_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_large.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_huge.fcmae_ft_in22k_in1k_384': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_384_ema.pt",
hf_hub_id='timm/',
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0, crop_mode='squash'),
'convnextv2_huge.fcmae_ft_in22k_in1k_512': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt",
hf_hub_id='timm/',
input_size=(3, 512, 512), pool_size=(15, 15), crop_pct=1.0, crop_mode='squash'),
'convnextv2_atto.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_atto_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_femto.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_femto_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_pico.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_pico_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=0.95),
'convnextv2_nano.fcmae_ft_in1k': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_nano_1k_224_ema.pt',
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_tiny.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_tiny_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_base.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_base_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_large.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_large_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnextv2_huge.fcmae_ft_in1k': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/im1k/convnextv2_huge_1k_224_ema.pt",
hf_hub_id='timm/',
test_input_size=(3, 288, 288), test_crop_pct=1.0),
'convnext_tiny_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", num_classes=21841),
'convnext_small_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", num_classes=21841),
'convnext_base_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
'convnext_large_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", num_classes=21841),
'convnext_xlarge_in22k.fb_in22k': _cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", num_classes=21841),
'convnextv2_atto.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_atto_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_femto.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_femto_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_pico.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_pico_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_nano.fcmae': _cfgv2(
url='https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_nano_1k_224_fcmae.pt',
hf_hub_id='timm/',
num_classes=0),
'convnextv2_tiny.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_tiny_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_base.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_base_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_large.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_huge.fcmae': _cfgv2(
url="https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt",
hf_hub_id='timm/',
num_classes=0),
'convnextv2_small.untrained': _cfg(),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
'convnext_base.clip_laiona_augreg_320': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w_320-laion_aesthetic-s13B-b82K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), crop_pct=1.0, num_classes=640),
})
@ -576,3 +844,82 @@ def convnext_xlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)
model = _create_convnext('convnext_xlarge', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_xxlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 4, 30, 3], dims=[384, 768, 1536, 3072], **kwargs)
model = _create_convnext('convnext_xxlarge', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_atto(pretrained=False, **kwargs):
# timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
model_args = dict(
depths=(2, 2, 6, 2), dims=(40, 80, 160, 320), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_atto', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_femto(pretrained=False, **kwargs):
# timm femto variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(48, 96, 192, 384), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_femto', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_pico(pretrained=False, **kwargs):
# timm pico variant
model_args = dict(
depths=(2, 2, 6, 2), dims=(64, 128, 256, 512), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_pico', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_nano(pretrained=False, **kwargs):
# timm nano variant with standard stem and head
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), use_grn=True, ls_init_value=None, conv_mlp=True, **kwargs)
model = _create_convnext('convnextv2_nano', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_tiny(pretrained=False, **kwargs):
model_args = dict(
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_tiny', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_small(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_small', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_base(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_base', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_large(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_large', pretrained=pretrained, **model_args)
return model
@register_model
def convnextv2_huge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[352, 704, 1408, 2816], use_grn=True, ls_init_value=None, **kwargs)
model = _create_convnext('convnextv2_huge', pretrained=pretrained, **model_args)
return model

@ -12,7 +12,7 @@ Reference impl via darknet cfg files at https://github.com/WongKinYiu/CrossStage
Hacked together by / Copyright 2020 Ross Wightman
"""
from dataclasses import dataclass, asdict
from dataclasses import dataclass, asdict, replace
from functools import partial
from typing import Any, Dict, Optional, Tuple, Union
@ -518,7 +518,7 @@ class CrossStage(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage, self).__init__()
first_dilation = first_dilation or dilation
@ -558,7 +558,7 @@ class CrossStage(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -597,7 +597,7 @@ class CrossStage3(nn.Module):
cross_linear=False,
block_dpr=None,
block_fn=BottleneckBlock,
**block_kwargs
**block_kwargs,
):
super(CrossStage3, self).__init__()
first_dilation = first_dilation or dilation
@ -635,7 +635,7 @@ class CrossStage3(nn.Module):
bottle_ratio=bottle_ratio,
groups=groups,
drop_path=block_dpr[i] if block_dpr is not None else 0.,
**block_kwargs
**block_kwargs,
))
prev_chs = block_out_chs
@ -668,7 +668,7 @@ class DarkStage(nn.Module):
avg_down=False,
block_fn=BottleneckBlock,
block_dpr=None,
**block_kwargs
**block_kwargs,
):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
@ -715,7 +715,7 @@ def create_csp_stem(
padding='',
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
aa_layer=None
aa_layer=None,
):
stem = nn.Sequential()
feature_info = []
@ -738,7 +738,7 @@ def create_csp_stem(
stride=conv_stride,
padding=padding if i == 0 else '',
act_layer=act_layer,
norm_layer=norm_layer
norm_layer=norm_layer,
))
stem_stride *= conv_stride
prev_chs = chs
@ -800,7 +800,7 @@ def create_csp_stages(
cfg: CspModelCfg,
drop_path_rate: float,
output_stride: int,
stem_feat: Dict[str, Any]
stem_feat: Dict[str, Any],
):
cfg_dict = asdict(cfg.stages)
num_stages = len(cfg.stages.depth)
@ -868,12 +868,27 @@ class CspNet(nn.Module):
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (CspModelCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # overlay kwargs onto cfg
layer_args = dict(
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,

@ -12,6 +12,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
import itertools
import torch
@ -20,7 +21,7 @@ import torch.nn.functional as F
from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, ClassifierHead, Mlp
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp # ClassifierHead
from ._builder import build_model_with_cfg
from ._features import FeatureInfo
from ._features_fx import register_notrace_function
@ -407,7 +408,11 @@ class DaViTStage(nn.Module):
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x : Tensor):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@ -455,7 +460,8 @@ class DaViT(nn.Module):
drop_rate=0.,
attn_drop_rate=0.,
num_classes=1000,
global_pool='avg'
global_pool='avg',
head_norm_first=False,
):
super().__init__()
@ -503,11 +509,19 @@ class DaViT(nn.Module):
stages.append(stage)
self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')]
self.stages = nn.Sequential(*stages)
self.norms = norm_layer(self.num_features)
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
self.apply(self._init_weights)
def _init_weights(self, m):
@ -522,40 +536,44 @@ class DaViT(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is None:
global_pool = self.head.global_pool.pool_type
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
x = self.stages(x)
# take final feature and norm
x = self.norms(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
#H, W = sizes[-1]
#x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous()
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm_pre(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x
def forward_head(self, x, pre_logits: bool = False):
return self.head(x, pre_logits=pre_logits)
x = self.head.global_pool(x)
x = self.head.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward_classifier(self, x):
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def forward(self, x):
return self.forward_classifier(x)
def checkpoint_filter_fn(state_dict, model):
""" Remap MSFT checkpoints -> timm """
if 'head.norm.weight' in state_dict:
if 'head' in state_dict:
return state_dict # non-MSFT checkpoint
if 'state_dict' in state_dict:
@ -569,6 +587,7 @@ def checkpoint_filter_fn(state_dict, model):
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('stages.0.patch_embed', 'patch_embed')
k = k.replace('head.', 'head.fc.')
k = k.replace('norms.', 'head.norm.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')
out_dict[k] = v
@ -577,8 +596,6 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
@ -594,11 +611,11 @@ def _create_davit(variant, pretrained=False, **kwargs):
def _cfg(url='', **kwargs): # not sure how this should be set up
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'crop_pct': 0.850, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
**kwargs

@ -12,7 +12,7 @@ import torch.utils.checkpoint as cp
from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from timm.layers import BatchNormAct2d, get_norm_act_layer, BlurPool2d, create_classifier
from ._builder import build_model_with_cfg
from ._manipulate import MATCH_PREV_GROUP
from ._registry import register_model
@ -115,8 +115,15 @@ class DenseBlock(nn.ModuleDict):
_version = 2
def __init__(
self, num_layers, num_input_features, bn_size, growth_rate, norm_layer=BatchNormAct2d,
drop_rate=0., memory_efficient=False):
self,
num_layers,
num_input_features,
bn_size,
growth_rate,
norm_layer=BatchNormAct2d,
drop_rate=0.,
memory_efficient=False,
):
super(DenseBlock, self).__init__()
for i in range(num_layers):
layer = DenseLayer(
@ -165,12 +172,25 @@ class DenseNet(nn.Module):
"""
def __init__(
self, growth_rate=32, block_config=(6, 12, 24, 16), num_classes=1000, in_chans=3, global_pool='avg',
bn_size=4, stem_type='', norm_layer=BatchNormAct2d, aa_layer=None, drop_rate=0,
memory_efficient=False, aa_stem_only=True):
self,
growth_rate=32,
block_config=(6, 12, 24, 16),
num_classes=1000,
in_chans=3,
global_pool='avg',
bn_size=4,
stem_type='',
act_layer='relu',
norm_layer='batchnorm2d',
aa_layer=None,
drop_rate=0,
memory_efficient=False,
aa_stem_only=True,
):
self.num_classes = num_classes
self.drop_rate = drop_rate
super(DenseNet, self).__init__()
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
# Stem
deep_stem = 'deep' in stem_type # 3x3 deep stem
@ -226,8 +246,11 @@ class DenseNet(nn.Module):
dict(num_chs=num_features, reduction=current_stride, module='features.' + module_name)]
current_stride *= 2
trans = DenseTransition(
num_input_features=num_features, num_output_features=num_features // 2,
norm_layer=norm_layer, aa_layer=transition_aa_layer)
num_input_features=num_features,
num_output_features=num_features // 2,
norm_layer=norm_layer,
aa_layer=transition_aa_layer,
)
self.features.add_module(f'transition{i + 1}', trans)
num_features = num_features // 2
@ -322,8 +345,8 @@ def densenetblur121d(pretrained=False, **kwargs):
`"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`
"""
model = _create_densenet(
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained, stem_type='deep',
aa_layer=BlurPool2d, **kwargs)
'densenetblur121d', growth_rate=32, block_config=(6, 12, 24, 16), pretrained=pretrained,
stem_type='deep', aa_layer=BlurPool2d, **kwargs)
return model
@ -382,11 +405,9 @@ def densenet264(pretrained=False, **kwargs):
def densenet264d_iabn(pretrained=False, **kwargs):
r"""Densenet-264 model with deep stem and Inplace-ABN
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
model = _create_densenet(
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)
norm_layer='iabn', act_layer='leaky_relu', pretrained=pretrained, **kwargs)
return model

@ -15,7 +15,7 @@ import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from timm.layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._registry import register_model
@ -33,6 +33,7 @@ def _cfg(url='', **kwargs):
default_cfgs = {
'dpn48b': _cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
'dpn68': _cfg(
url='https://github.com/rwightman/pytorch-dpn-pretrained/releases/download/v0.1/dpn68-66bebafa7.pth'),
'dpn68b': _cfg(
@ -82,7 +83,16 @@ class BnActConv2d(nn.Module):
class DualPathBlock(nn.Module):
def __init__(
self, in_chs, num_1x1_a, num_3x3_b, num_1x1_c, inc, groups, block_type='normal', b=False):
self,
in_chs,
num_1x1_a,
num_3x3_b,
num_1x1_c,
inc,
groups,
block_type='normal',
b=False,
):
super(DualPathBlock, self).__init__()
self.num_1x1_c = num_1x1_c
self.inc = inc
@ -167,16 +177,31 @@ class DualPathBlock(nn.Module):
class DPN(nn.Module):
def __init__(
self, small=False, num_init_features=64, k_r=96, groups=32, global_pool='avg',
b=False, k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), output_stride=32,
num_classes=1000, in_chans=3, drop_rate=0., fc_act_layer=nn.ELU):
self,
k_sec=(3, 4, 20, 3),
inc_sec=(16, 32, 24, 128),
k_r=96,
groups=32,
num_classes=1000,
in_chans=3,
output_stride=32,
global_pool='avg',
small=False,
num_init_features=64,
b=False,
drop_rate=0.,
norm_layer='batchnorm2d',
act_layer='relu',
fc_act_layer='elu',
):
super(DPN, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.b = b
assert output_stride == 32 # FIXME look into dilation support
norm_layer = partial(BatchNormAct2d, eps=.001)
fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act_layer, inplace=False)
norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=act_layer), eps=.001)
fc_norm_layer = partial(get_norm_act_layer(norm_layer, act_layer=fc_act_layer), eps=.001, inplace=False)
bw_factor = 1 if small else 4
blocks = OrderedDict()
@ -291,49 +316,57 @@ def _create_dpn(variant, pretrained=False, **kwargs):
**kwargs)
@register_model
def dpn48b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 6, 3), inc_sec=(16, 32, 32, 64), act_layer='silu')
return _create_dpn('dpn48b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn68b(pretrained=False, **kwargs):
model_kwargs = dict(
small=True, num_init_features=10, k_r=128, groups=32,
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64), **kwargs)
return _create_dpn('dpn68b', pretrained=pretrained, **model_kwargs)
b=True, k_sec=(3, 4, 12, 3), inc_sec=(16, 32, 32, 64))
return _create_dpn('dpn68b', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn92(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=64, k_r=96, groups=32,
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128), **kwargs)
return _create_dpn('dpn92', pretrained=pretrained, **model_kwargs)
k_sec=(3, 4, 20, 3), inc_sec=(16, 32, 24, 128))
return _create_dpn('dpn92', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn98(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=96, k_r=160, groups=40,
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn98', pretrained=pretrained, **model_kwargs)
k_sec=(3, 6, 20, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn98', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn131(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=160, groups=40,
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128), **kwargs)
return _create_dpn('dpn131', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 28, 3), inc_sec=(16, 32, 32, 128))
return _create_dpn('dpn131', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
def dpn107(pretrained=False, **kwargs):
model_kwargs = dict(
num_init_features=128, k_r=200, groups=50,
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128), **kwargs)
return _create_dpn('dpn107', pretrained=pretrained, **model_kwargs)
k_sec=(4, 8, 20, 3), inc_sec=(20, 64, 64, 128))
return _create_dpn('dpn107', pretrained=pretrained, **dict(model_kwargs, **kwargs))

File diff suppressed because it is too large Load Diff

@ -2,6 +2,7 @@
from timm.layers.activations import *
from timm.layers.adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from timm.layers.attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from timm.layers.blur_pool import BlurPool2d
from timm.layers.classifier import ClassifierHead, create_classifier
from timm.layers.cond_conv2d import CondConv2d, get_condconv_initializer

@ -47,16 +47,15 @@ import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, trunc_normal_tf_, LayerNorm
from timm.layers import SelectAdaptivePool2d, create_pool2d
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d
from timm.layers import to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
from .vision_transformer_relpos import RelPosMlp, RelPosBias # FIXME move these to common location
__all__ = ['MaxxVitCfg', 'MaxxVitConvCfg', 'MaxxVitTransformerCfg', 'MaxxVit']
@ -1076,93 +1075,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
return cfg
def generate_lookup_tensor(
length: int,
max_relative_position: Optional[int] = None,
):
"""Generate a one_hot lookup tensor to reindex embeddings along one dimension.
Args:
length: the length to reindex to.
max_relative_position: the maximum relative position to consider.
Relative position embeddings for distances above this threshold
are zeroed out.
Returns:
a lookup Tensor of size [length, length, vocab_size] that satisfies
ret[n,m,v] = 1{m - n + max_relative_position = v}.
"""
if max_relative_position is None:
max_relative_position = length - 1
# Return the cached lookup tensor, otherwise compute it and cache it.
vocab_size = 2 * max_relative_position + 1
ret = torch.zeros(length, length, vocab_size)
for i in range(length):
for x in range(length):
v = x - i + max_relative_position
if abs(x - i) > max_relative_position:
continue
ret[i, x, v] = 1
return ret
def reindex_2d_einsum_lookup(
relative_position_tensor,
height: int,
width: int,
height_lookup: torch.Tensor,
width_lookup: torch.Tensor,
) -> torch.Tensor:
"""Reindex 2d relative position bias with 2 independent einsum lookups.
Args:
relative_position_tensor: tensor of shape
[..., vocab_height, vocab_width, ...].
height: height to reindex to.
width: width to reindex to.
height_lookup: one-hot height lookup
width_lookup: one-hot width lookup
Returns:
reindexed_tensor: a Tensor of shape
[..., height * width, height * width, ...]
"""
reindexed_tensor = torch.einsum('nhw,ixh->nixw', relative_position_tensor, height_lookup)
reindexed_tensor = torch.einsum('nixw,jyw->nijxy', reindexed_tensor, width_lookup)
area = height * width
return reindexed_tensor.reshape(relative_position_tensor.shape[0], area, area)
class RelPosBiasTf(nn.Module):
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.num_heads = num_heads
vocab_height = 2 * window_size[0] - 1
vocab_width = 2 * window_size[1] - 1
self.bias_shape = (self.num_heads, vocab_height, vocab_width)
self.relative_position_bias_table = nn.Parameter(torch.zeros(self.bias_shape))
self.register_buffer('height_lookup', generate_lookup_tensor(window_size[0]), persistent=False)
self.register_buffer('width_lookup', generate_lookup_tensor(window_size[1]), persistent=False)
self.init_weights()
def init_weights(self):
nn.init.normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
# FIXME change to not use one-hot/einsum?
return reindex_2d_einsum_lookup(
self.relative_position_bias_table,
self.window_size[0],
self.window_size[1],
self.height_lookup,
self.width_lookup
)
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class NormMlpHead(nn.Module):
def __init__(
@ -1204,6 +1116,26 @@ class NormMlpHead(nn.Module):
return x
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {}
conv_kwargs = {}
base_kwargs = {}
for k, v in kwargs.items():
if k.startswith('transformer_'):
transformer_kwargs[k.replace('transformer_', '')] = v
elif k.startswith('conv_'):
conv_kwargs[k.replace('conv_', '')] = v
else:
base_kwargs[k] = v
cfg = replace(
cfg,
transformer_cfg=replace(cfg.transformer_cfg, **transformer_kwargs),
conv_cfg=replace(cfg.conv_cfg, **conv_kwargs),
**base_kwargs
)
return cfg
class MaxxVit(nn.Module):
""" CoaTNet + MaxVit base model.
@ -1218,10 +1150,13 @@ class MaxxVit(nn.Module):
num_classes: int = 1000,
global_pool: str = 'avg',
drop_rate: float = 0.,
drop_path_rate: float = 0.
drop_path_rate: float = 0.,
**kwargs,
):
super().__init__()
img_size = to_2tuple(img_size)
if kwargs:
cfg = _overlay_kwargs(cfg, **kwargs)
transformer_cfg = cfg_window_size(cfg.transformer_cfg, img_size)
self.num_classes = num_classes
self.global_pool = global_pool
@ -1745,6 +1680,26 @@ model_cfgs = dict(
init_values=1e-6,
),
),
maxvit_rmlp_base_rw_224=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
),
),
maxvit_rmlp_base_rw_384=MaxxVitCfg(
embed_dim=(96, 192, 384, 768),
depths=(2, 6, 14, 2),
block_type=('M',) * 4,
stem_width=(32, 64),
head_hidden_size=768,
**_rw_max_cfg(
rel_pos_type='mlp',
),
),
maxvit_tiny_pm_256=MaxxVitCfg(
embed_dim=(64, 128, 256, 512),
@ -1927,6 +1882,12 @@ default_cfgs = generate_default_cfgs({
'maxvit_rmlp_small_rw_256': _cfg(
url='',
input_size=(3, 256, 256), pool_size=(8, 8)),
'maxvit_rmlp_base_rw_224': _cfg(
url='',
),
'maxvit_rmlp_base_rw_384': _cfg(
url='',
input_size=(3, 384, 384), pool_size=(12, 12)),
'maxvit_tiny_pm_256': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@ -2156,6 +2117,16 @@ def maxvit_rmlp_small_rw_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_small_rw_256', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_224(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_224', pretrained=pretrained, **kwargs)
@register_model
def maxvit_rmlp_base_rw_384(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_rmlp_base_rw_384', pretrained=pretrained, **kwargs)
@register_model
def maxvit_tiny_pm_256(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_tiny_pm_256', pretrained=pretrained, **kwargs)

@ -21,93 +21,12 @@ from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficie
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from ._features import FeatureInfo, FeatureHooks
from ._manipulate import checkpoint_seq
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['MobileNetV3', 'MobileNetV3Features']
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = {
'mobilenetv3_large_075': _cfg(url=''),
'mobilenetv3_large_100': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
'mobilenetv3_large_100_miil': _cfg(
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth'),
'mobilenetv3_large_100_miil_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
'mobilenetv3_small_050': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
interpolation='bicubic'),
'mobilenetv3_small_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
interpolation='bicubic'),
'mobilenetv3_small_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
interpolation='bicubic'),
'mobilenetv3_rw': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_075': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_100': _cfg(
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_g': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
"lcnet_035": _cfg(),
"lcnet_050": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
interpolation='bicubic',
),
"lcnet_075": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
interpolation='bicubic',
),
"lcnet_100": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
interpolation='bicubic',
),
"lcnet_150": _cfg(),
}
class MobileNetV3(nn.Module):
""" MobiletNet-V3
@ -124,9 +43,24 @@ class MobileNetV3(nn.Module):
"""
def __init__(
self, block_args, num_classes=1000, in_chans=3, stem_size=16, fix_stem=False, num_features=1280,
head_bias=True, pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
self,
block_args,
num_classes=1000,
in_chans=3,
stem_size=16,
fix_stem=False,
num_features=1280,
head_bias=True,
pad_type='',
act_layer=None,
norm_layer=None,
se_layer=None,
se_from_exp=True,
round_chs_fn=round_channels,
drop_rate=0.,
drop_path_rate=0.,
global_pool='avg',
):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -145,8 +79,15 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
output_stride=32,
pad_type=pad_type,
round_chs_fn=round_chs_fn,
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
head_chs = builder.in_chs
@ -225,9 +166,23 @@ class MobileNetV3Features(nn.Module):
"""
def __init__(
self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, fix_stem=False, output_stride=32, pad_type='', round_chs_fn=round_channels,
se_from_exp=True, act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
self,
block_args,
out_indices=(0, 1, 2, 3, 4),
feature_location='bottleneck',
in_chans=3,
stem_size=16,
fix_stem=False,
output_stride=32,
pad_type='',
round_chs_fn=round_channels,
se_from_exp=True,
act_layer=None,
norm_layer=None,
se_layer=None,
drop_rate=0.,
drop_path_rate=0.,
):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
@ -243,9 +198,16 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
drop_path_rate=drop_path_rate, feature_location=feature_location)
output_stride=output_stride,
pad_type=pad_type,
round_chs_fn=round_chs_fn,
se_from_exp=se_from_exp,
act_layer=act_layer,
norm_layer=norm_layer,
se_layer=se_layer,
drop_path_rate=drop_path_rate,
feature_location=feature_location,
)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = FeatureInfo(builder.features, out_indices)
self._stage_out_idx = {v['stage']: i for i, v in enumerate(self.feature_info) if i in out_indices}
@ -286,7 +248,9 @@ def _create_mnv3(variant, pretrained=False, **kwargs):
kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
model_cls = MobileNetV3Features
model = build_model_with_cfg(
model_cls, variant, pretrained,
model_cls,
variant,
pretrained,
pretrained_strict=not features_only,
kwargs_filter=kwargs_filter,
**kwargs)
@ -567,6 +531,110 @@ def _gen_lcnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bilinear',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'conv_stem', 'classifier': 'classifier',
**kwargs
}
default_cfgs = generate_default_cfgs({
'mobilenetv3_large_075.untrained': _cfg(url=''),
'mobilenetv3_large_100.ra_in1k': _cfg(
interpolation='bicubic',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth',
hf_hub_id='timm/'),
'mobilenetv3_large_100.miil_in21k_ft_in1k': _cfg(
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.),
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
paper_ids='arXiv:2104.10972v4',
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_1k_miil_78_0-66471c13.pth',
hf_hub_id='timm/'),
'mobilenetv3_large_100.miil_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/mobilenetv3_large_100_in21k_miil-d71cc17b.pth',
hf_hub_id='timm/',
origin_url='https://github.com/Alibaba-MIIL/ImageNet21K',
paper_ids='arXiv:2104.10972v4',
interpolation='bilinear', mean=(0., 0., 0.), std=(1., 1., 1.), num_classes=11221),
'mobilenetv3_small_050.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_050_lambc-4b7bbe87.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_small_075.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_075_lambc-384766db.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_small_100.lamb_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_small_100_lamb-266a294c.pth',
hf_hub_id='timm/',
interpolation='bicubic'),
'mobilenetv3_rw.rmsp_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth',
interpolation='bicubic'),
'tf_mobilenetv3_large_075.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_large_minimal_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_075.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_100.in1k': _cfg(
url= 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'tf_mobilenetv3_small_minimal_100.in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
hf_hub_id='timm/',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_b_224-ead5d2a1.pth',
hf_hub_id='timm/',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_d.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_d_224-c98bce42.pth',
hf_hub_id='timm/',
test_input_size=(3, 256, 256), crop_pct=0.95),
'fbnetv3_g.ra2_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetv3_g_240-0b1df83b.pth',
hf_hub_id='timm/',
input_size=(3, 240, 240), test_input_size=(3, 288, 288), crop_pct=0.95, pool_size=(8, 8)),
"lcnet_035.untrained": _cfg(),
"lcnet_050.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_050-f447553b.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_075.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_075-318cad2c.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_100.ra2_in1k": _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/lcnet_100-a929038c.pth',
hf_hub_id='timm/',
interpolation='bicubic',
),
"lcnet_150.untrained": _cfg(),
})
@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
@ -581,24 +649,6 @@ def mobilenetv3_large_100(pretrained=False, **kwargs):
return model
@register_model
def mobilenetv3_large_100_miil(pretrained=False, **kwargs):
""" MobileNet V3
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_large_100_miil_in21k(pretrained=False, **kwargs):
""" MobileNet V3, 21k pretraining
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model = _gen_mobilenet_v3('mobilenetv3_large_100_miil_in21k', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def mobilenetv3_small_050(pretrained=False, **kwargs):
""" MobileNet V3 """

@ -266,9 +266,16 @@ class MobileVitBlock(nn.Module):
self.transformer = nn.Sequential(*[
TransformerBlock(
transformer_dim, mlp_ratio=mlp_ratio, num_heads=num_heads, qkv_bias=True,
attn_drop=attn_drop, drop=drop, drop_path=drop_path_rate,
act_layer=layers.act, norm_layer=transformer_norm_layer)
transformer_dim,
mlp_ratio=mlp_ratio,
num_heads=num_heads,
qkv_bias=True,
attn_drop=attn_drop,
drop=drop,
drop_path=drop_path_rate,
act_layer=layers.act,
norm_layer=transformer_norm_layer,
)
for _ in range(transformer_depth)
])
self.norm = transformer_norm_layer(transformer_dim)

@ -17,7 +17,7 @@ Status:
Hacked together by / copyright Ross Wightman, 2021.
"""
from collections import OrderedDict
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Tuple, Optional
@ -159,11 +159,25 @@ class NfCfg:
def _nfres_cfg(
depths, channels=(256, 512, 1024, 2048), group_size=None, act_layer='relu', attn_layer=None, attn_kwargs=None):
depths,
channels=(256, 512, 1024, 2048),
group_size=None,
act_layer='relu',
attn_layer=None,
attn_kwargs=None,
):
attn_kwargs = attn_kwargs or {}
cfg = NfCfg(
depths=depths, channels=channels, stem_type='7x7_pool', stem_chs=64, bottle_ratio=0.25,
group_size=group_size, act_layer=act_layer, attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='7x7_pool',
stem_chs=64,
bottle_ratio=0.25,
group_size=group_size,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
@ -171,28 +185,70 @@ def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='3x3',
group_size=8,
width_factor=0.75,
bottle_ratio=2.25,
num_features=num_features,
reg=True,
attn_layer='se',
attn_kwargs=attn_kwargs,
)
return cfg
def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
depths,
channels=(256, 512, 1536, 1536),
group_size=128,
bottle_ratio=0.5,
feat_mult=2.,
act_layer='gelu',
attn_layer='se',
attn_kwargs=None,
):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
attn_layer=attn_layer, attn_kwargs=attn_kwargs)
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=group_size,
bottle_ratio=bottle_ratio,
extra_conv=True,
num_features=num_features,
act_layer=act_layer,
attn_layer=attn_layer,
attn_kwargs=attn_kwargs,
)
return cfg
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
def _dm_nfnet_cfg(
depths,
channels=(256, 512, 1536, 1536),
act_layer='gelu',
skipinit=True,
):
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
depths=depths,
channels=channels,
stem_type='deep_quad',
stem_chs=128,
group_size=128,
bottle_ratio=0.5,
extra_conv=True,
gamma_in_act=True,
same_padding=True,
skipinit=skipinit,
num_features=int(channels[-1] * 2.0),
act_layer=act_layer,
attn_layer='se',
attn_kwargs=dict(rd_ratio=0.5),
)
return cfg
@ -278,7 +334,14 @@ def act_with_gamma(act_type, gamma: float = 1.):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
conv_layer=ScaledStdConv2d,
):
""" AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -299,9 +362,26 @@ class NormFreeBlock(nn.Module):
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, group_size=None, ch_div=1, reg=True, extra_conv=False,
skipinit=False, attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
stride=1,
dilation=1,
first_dilation=None,
alpha=1.0,
beta=1.0,
bottle_ratio=0.25,
group_size=None,
ch_div=1,
reg=True,
extra_conv=False,
skipinit=False,
attn_layer=None,
attn_gain=2.0,
act_layer=None,
conv_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
@ -316,7 +396,13 @@ class NormFreeBlock(nn.Module):
if in_chs != out_chs or stride != 1 or dilation != first_dilation:
self.downsample = DownsampleAvg(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
in_chs,
out_chs,
stride=stride,
dilation=dilation,
first_dilation=first_dilation,
conv_layer=conv_layer,
)
else:
self.downsample = None
@ -452,14 +538,33 @@ class NormFreeNet(nn.Module):
for what it is/does. Approx 8-10% throughput loss.
"""
def __init__(
self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.
self,
cfg: NfCfg,
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (NfCfg): Model architecture configuration
num_classes (int): Number of classifier classes (default: 1000)
in_chans (int): Number of input channels (default: 3)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
self.grad_checkpointing = False
cfg = replace(cfg, **kwargs)
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = ScaledStdConv2dSame if cfg.same_padding else ScaledStdConv2d
if cfg.gamma_in_act:
@ -472,7 +577,12 @@ class NormFreeNet(nn.Module):
stem_chs = make_divisible((cfg.stem_chs or cfg.channels[0]) * cfg.width_factor, cfg.ch_div)
self.stem, stem_stride, stem_feat = create_stem(
in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer, act_layer=act_layer)
in_chans,
stem_chs,
cfg.stem_type,
conv_layer=conv_layer,
act_layer=act_layer,
)
self.feature_info = [stem_feat]
drop_path_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]

@ -14,7 +14,7 @@ Weights from original impl have been modified
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
from dataclasses import dataclass
from dataclasses import dataclass, replace
from functools import partial
from typing import Optional, Union, Callable
@ -237,7 +237,15 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
def create_shortcut(
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
downsample_type,
in_chs,
out_chs,
kernel_size,
stride,
dilation=(1, 1),
norm_layer=None,
preact=False,
):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
@ -259,9 +267,21 @@ class Bottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(Bottleneck, self).__init__()
act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -307,9 +327,21 @@ class PreBottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
self,
in_chs,
out_chs,
stride=1,
dilation=(1, 1),
bottle_ratio=1,
group_size=1,
se_ratio=0.25,
downsample='conv1x1',
linear_out=False,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
drop_block=None,
drop_path_rate=0.,
):
super(PreBottleneck, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -353,8 +385,16 @@ class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(
self, depth, in_chs, out_chs, stride, dilation,
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
self,
depth,
in_chs,
out_chs,
stride,
dilation,
drop_path_rates=None,
block_fn=Bottleneck,
**block_kwargs,
):
super(RegStage, self).__init__()
self.grad_checkpointing = False
@ -367,8 +407,13 @@ class RegStage(nn.Module):
name = "b{}".format(i + 1)
self.add_module(
name, block_fn(
block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
drop_path_rate=dpr, **block_kwargs)
block_in_chs,
out_chs,
stride=block_stride,
dilation=block_dilation,
drop_path_rate=dpr,
**block_kwargs,
)
)
first_dilation = dilation
@ -389,12 +434,35 @@ class RegNet(nn.Module):
"""
def __init__(
self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
drop_rate=0., drop_path_rate=0., zero_init_last=True):
self,
cfg: RegNetCfg,
in_chans=3,
num_classes=1000,
output_stride=32,
global_pool='avg',
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=True,
**kwargs,
):
"""
Args:
cfg (RegNetCfg): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
global_pool (str): Global pooling type (default: 'avg')
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
zero_init_last (bool): Zero-init last weight of residual path
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
cfg = replace(cfg, **kwargs) # update cfg with extra passed kwargs
# Construct the stem
stem_width = cfg.stem_width
@ -461,8 +529,12 @@ class RegNet(nn.Module):
dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_br, stage_gs, stage_dpr)]
common_args = dict(
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
downsample=cfg.downsample,
se_ratio=cfg.se_ratio,
linear_out=cfg.linear_out,
act_layer=cfg.act_layer,
norm_layer=cfg.norm_layer,
)
return per_stage_args, common_args
@torch.jit.ignore
@ -518,7 +590,6 @@ def _init_weights(module, name='', zero_init_last=False):
def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'classy_state_dict' in state_dict:
import re
state_dict = state_dict['classy_state_dict']['base_model']['model']

@ -51,9 +51,21 @@ class Bottle2neck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=26, scale=4, dilation=1, first_dilation=None,
act_layer=nn.ReLU, norm_layer=None, attn_layer=None, **_):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=26,
scale=4,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=None,
attn_layer=None,
**_,
):
super(Bottle2neck, self).__init__()
self.scale = scale
self.is_first = stride > 1 or downsample is not None
@ -89,7 +101,8 @@ class Bottle2neck(nn.Module):
self.downsample = downsample
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
shortcut = x
@ -143,8 +156,8 @@ def res2net50_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net50_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net50_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -154,8 +167,8 @@ def res2net101_26w_4s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2net101_26w_4s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 23, 3], base_width=26, block_args=dict(scale=4))
return _create_res2net('res2net101_26w_4s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -165,8 +178,8 @@ def res2net50_26w_6s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6), **kwargs)
return _create_res2net('res2net50_26w_6s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=6))
return _create_res2net('res2net50_26w_6s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -176,8 +189,8 @@ def res2net50_26w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_26w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=26, block_args=dict(scale=8))
return _create_res2net('res2net50_26w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -187,8 +200,8 @@ def res2net50_48w_2s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2), **kwargs)
return _create_res2net('res2net50_48w_2s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=48, block_args=dict(scale=2))
return _create_res2net('res2net50_48w_2s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -198,8 +211,8 @@ def res2net50_14w_8s(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8), **kwargs)
return _create_res2net('res2net50_14w_8s', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=14, block_args=dict(scale=8))
return _create_res2net('res2net50_14w_8s', pretrained, **dict(model_args, **kwargs))
@register_model
@ -209,5 +222,5 @@ def res2next50(pretrained=False, **kwargs):
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model_args = dict(
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4), **kwargs)
return _create_res2net('res2next50', pretrained, **model_args)
block=Bottle2neck, layers=[3, 4, 6, 3], base_width=4, cardinality=8, block_args=dict(scale=4))
return _create_res2net('res2next50', pretrained, **dict(model_args, **kwargs))

@ -57,10 +57,27 @@ class ResNestBottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None,
radix=1, cardinality=1, base_width=64, avd=False, avd_first=False, is_first=False,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
radix=1,
cardinality=1,
base_width=64,
avd=False,
avd_first=False,
is_first=False,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(ResNestBottleneck, self).__init__()
assert reduce_first == 1 # not supported
assert attn_layer is None # not supported
@ -103,7 +120,8 @@ class ResNestBottleneck(nn.Module):
self.downsample = downsample
def zero_init_last(self):
nn.init.zeros_(self.bn3.weight)
if getattr(self.bn3, 'weight', None) is not None:
nn.init.zeros_(self.bn3.weight)
def forward(self, x):
shortcut = x
@ -145,8 +163,8 @@ def resnest14d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[1, 1, 1, 1],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest14d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest14d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -156,8 +174,8 @@ def resnest26d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[2, 2, 2, 2],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest26d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest26d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -168,8 +186,8 @@ def resnest50d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest50d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest50d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -180,8 +198,8 @@ def resnest101e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 23, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest101e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest101e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -192,8 +210,8 @@ def resnest200e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 24, 36, 3],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest200e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest200e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -204,8 +222,8 @@ def resnest269e(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 30, 48, 8],
stem_type='deep', stem_width=64, avg_down=True, base_width=64, cardinality=1,
block_args=dict(radix=2, avd=True, avd_first=False), **kwargs)
return _create_resnest('resnest269e', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=2, avd=True, avd_first=False))
return _create_resnest('resnest269e', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -215,8 +233,8 @@ def resnest50d_4s2x40d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=40, cardinality=2,
block_args=dict(radix=4, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=4, avd=True, avd_first=True))
return _create_resnest('resnest50d_4s2x40d', pretrained=pretrained, **dict(model_kwargs, **kwargs))
@register_model
@ -226,5 +244,5 @@ def resnest50d_1s4x24d(pretrained=False, **kwargs):
model_kwargs = dict(
block=ResNestBottleneck, layers=[3, 4, 6, 3],
stem_type='deep', stem_width=32, avg_down=True, base_width=24, cardinality=4,
block_args=dict(radix=1, avd=True, avd_first=True), **kwargs)
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **model_kwargs)
block_args=dict(radix=1, avd=True, avd_first=True))
return _create_resnest('resnest50d_1s4x24d', pretrained=pretrained, **dict(model_kwargs, **kwargs))

File diff suppressed because it is too large Load Diff

@ -37,7 +37,7 @@ import torch.nn as nn
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0, FilterResponseNormTlu2d, \
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d
ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d, get_act_layer, get_norm_act_layer
from ._builder import build_model_with_cfg
from ._manipulate import checkpoint_seq, named_apply, adapt_input_conv
from ._registry import register_model
@ -155,8 +155,20 @@ class PreActBottleneck(nn.Module):
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
bottle_ratio=0.25,
stride=1,
dilation=1,
first_dilation=None,
groups=1,
act_layer=None,
conv_layer=None,
norm_layer=None,
proj_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
conv_layer = conv_layer or StdConv2d
@ -202,8 +214,20 @@ class Bottleneck(nn.Module):
"""Non Pre-activation bottleneck block, equiv to V1.5/V1b Bottleneck. Used for ViT.
"""
def __init__(
self, in_chs, out_chs=None, bottle_ratio=0.25, stride=1, dilation=1, first_dilation=None, groups=1,
act_layer=None, conv_layer=None, norm_layer=None, proj_layer=None, drop_path_rate=0.):
self,
in_chs,
out_chs=None,
bottle_ratio=0.25,
stride=1,
dilation=1,
first_dilation=None,
groups=1,
act_layer=None,
conv_layer=None,
norm_layer=None,
proj_layer=None,
drop_path_rate=0.,
):
super().__init__()
first_dilation = first_dilation or dilation
act_layer = act_layer or nn.ReLU
@ -229,7 +253,8 @@ class Bottleneck(nn.Module):
self.act3 = act_layer(inplace=True)
def zero_init_last(self):
nn.init.zeros_(self.norm3.weight)
if getattr(self.norm3, 'weight', None) is not None:
nn.init.zeros_(self.norm3.weight)
def forward(self, x):
# shortcut branch
@ -251,8 +276,16 @@ class Bottleneck(nn.Module):
class DownsampleConv(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, preact=True,
conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
super(DownsampleConv, self).__init__()
self.conv = conv_layer(in_chs, out_chs, 1, stride=stride)
self.norm = nn.Identity() if preact else norm_layer(out_chs, apply_act=False)
@ -263,8 +296,16 @@ class DownsampleConv(nn.Module):
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None,
preact=True, conv_layer=None, norm_layer=None):
self,
in_chs,
out_chs,
stride=1,
dilation=1,
first_dilation=None,
preact=True,
conv_layer=None,
norm_layer=None,
):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
@ -283,9 +324,22 @@ class DownsampleAvg(nn.Module):
class ResNetStage(nn.Module):
"""ResNet Stage."""
def __init__(
self, in_chs, out_chs, stride, dilation, depth, bottle_ratio=0.25, groups=1,
avg_down=False, block_dpr=None, block_fn=PreActBottleneck,
act_layer=None, conv_layer=None, norm_layer=None, **block_kwargs):
self,
in_chs,
out_chs,
stride,
dilation,
depth,
bottle_ratio=0.25,
groups=1,
avg_down=False,
block_dpr=None,
block_fn=PreActBottleneck,
act_layer=None,
conv_layer=None,
norm_layer=None,
**block_kwargs,
):
super(ResNetStage, self).__init__()
first_dilation = 1 if dilation in (1, 2) else 2
layer_kwargs = dict(act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer)
@ -296,9 +350,18 @@ class ResNetStage(nn.Module):
drop_path_rate = block_dpr[block_idx] if block_dpr else 0.
stride = stride if block_idx == 0 else 1
self.blocks.add_module(str(block_idx), block_fn(
prev_chs, out_chs, stride=stride, dilation=dilation, bottle_ratio=bottle_ratio, groups=groups,
first_dilation=first_dilation, proj_layer=proj_layer, drop_path_rate=drop_path_rate,
**layer_kwargs, **block_kwargs))
prev_chs,
out_chs,
stride=stride,
dilation=dilation,
bottle_ratio=bottle_ratio,
groups=groups,
first_dilation=first_dilation,
proj_layer=proj_layer,
drop_path_rate=drop_path_rate,
**layer_kwargs,
**block_kwargs,
))
prev_chs = out_chs
first_dilation = dilation
proj_layer = None
@ -313,8 +376,13 @@ def is_stem_deep(stem_type):
def create_resnetv2_stem(
in_chs, out_chs=64, stem_type='', preact=True,
conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32)):
in_chs,
out_chs=64,
stem_type='',
preact=True,
conv_layer=StdConv2d,
norm_layer=partial(GroupNormAct, num_groups=32),
):
stem = OrderedDict()
assert stem_type in ('', 'fixed', 'same', 'deep', 'deep_fixed', 'deep_same', 'tiered')
@ -357,20 +425,62 @@ class ResNetV2(nn.Module):
"""
def __init__(
self, layers, channels=(256, 512, 1024, 2048),
num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
width_factor=1, stem_chs=64, stem_type='', avg_down=False, preact=True,
act_layer=nn.ReLU, conv_layer=StdConv2d, norm_layer=partial(GroupNormAct, num_groups=32),
drop_rate=0., drop_path_rate=0., zero_init_last=False):
self,
layers,
channels=(256, 512, 1024, 2048),
num_classes=1000,
in_chans=3,
global_pool='avg',
output_stride=32,
width_factor=1,
stem_chs=64,
stem_type='',
avg_down=False,
preact=True,
act_layer=nn.ReLU,
norm_layer=partial(GroupNormAct, num_groups=32),
conv_layer=StdConv2d,
drop_rate=0.,
drop_path_rate=0.,
zero_init_last=False,
):
"""
Args:
layers (List[int]) : number of layers in each block
channels (List[int]) : number of channels in each block:
num_classes (int): number of classification classes (default 1000)
in_chans (int): number of input (color) channels. (default 3)
global_pool (str): Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' (default 'avg')
output_stride (int): output stride of the network, 32, 16, or 8. (default 32)
width_factor (int): channel (width) multiplication factor
stem_chs (int): stem width (default: 64)
stem_type (str): stem type (default: '' == 7x7)
avg_down (bool): average pooling in residual downsampling (default: False)
preact (bool): pre-activiation (default: True)
act_layer (Union[str, nn.Module]): activation layer
norm_layer (Union[str, nn.Module]): normalization layer
conv_layer (nn.Module): convolution module
drop_rate: classifier dropout rate (default: 0.)
drop_path_rate: stochastic depth rate (default: 0.)
zero_init_last: zero-init last weight in residual path (default: False)
"""
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
wf = width_factor
norm_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
act_layer = get_act_layer(act_layer)
self.feature_info = []
stem_chs = make_div(stem_chs * wf)
self.stem = create_resnetv2_stem(
in_chans, stem_chs, stem_type, preact, conv_layer=conv_layer, norm_layer=norm_layer)
in_chans,
stem_chs,
stem_type,
preact,
conv_layer=conv_layer,
norm_layer=norm_layer,
)
stem_feat = ('stem.conv3' if is_stem_deep(stem_type) else 'stem.conv') if preact else 'stem.norm'
self.feature_info.append(dict(num_chs=stem_chs, reduction=2, module=stem_feat))
@ -387,8 +497,18 @@ class ResNetV2(nn.Module):
dilation *= stride
stride = 1
stage = ResNetStage(
prev_chs, out_chs, stride=stride, dilation=dilation, depth=d, avg_down=avg_down,
act_layer=act_layer, conv_layer=conv_layer, norm_layer=norm_layer, block_dpr=bdpr, block_fn=block_fn)
prev_chs,
out_chs,
stride=stride,
dilation=dilation,
depth=d,
avg_down=avg_down,
act_layer=act_layer,
conv_layer=conv_layer,
norm_layer=norm_layer,
block_dpr=bdpr,
block_fn=block_fn,
)
prev_chs = out_chs
curr_stride *= stride
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{stage_idx}')]
@ -626,86 +746,83 @@ def resnetv2_152x2_bit_teacher_384(pretrained=False, **kwargs):
@register_model
def resnetv2_50(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50', pretrained=pretrained,
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_50', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50t(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50t', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='tiered', avg_down=True, **kwargs)
stem_type='tiered', avg_down=True)
return _create_resnetv2('resnetv2_50t', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101', pretrained=pretrained,
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_101', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_101d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_101d', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 23, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_101d', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152', pretrained=pretrained,
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d, **kwargs)
model_args = dict(layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d)
return _create_resnetv2('resnetv2_152', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_152d(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_152d', pretrained=pretrained,
model_args = dict(
layers=[3, 8, 36, 3], conv_layer=create_conv2d, norm_layer=BatchNormAct2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_152d', pretrained=pretrained, **dict(model_args, **kwargs))
# Experimental configs (may change / be removed)
@register_model
def resnetv2_50d_gn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_gn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=GroupNormAct,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_gn', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evob(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evob', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0,
stem_type='deep', avg_down=True, zero_init_last=True, **kwargs)
stem_type='deep', avg_down=True, zero_init_last=True)
return _create_resnetv2('resnetv2_50d_evob', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_evos(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_evos', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_evos', pretrained=pretrained, **dict(model_args, **kwargs))
@register_model
def resnetv2_50d_frn(pretrained=False, **kwargs):
return _create_resnetv2(
'resnetv2_50d_frn', pretrained=pretrained,
model_args = dict(
layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d,
stem_type='deep', avg_down=True, **kwargs)
stem_type='deep', avg_down=True)
return _create_resnetv2('resnetv2_50d_frn', pretrained=pretrained, **dict(model_args, **kwargs))

@ -47,9 +47,24 @@ class SelectiveKernelBasic(nn.Module):
expansion = 1
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64,
sk_kwargs=None, reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
sk_kwargs=None,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -71,7 +86,8 @@ class SelectiveKernelBasic(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.conv2.bn.weight)
if getattr(self.conv2.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv2.bn.weight)
def forward(self, x):
shortcut = x
@ -92,9 +108,24 @@ class SelectiveKernelBottleneck(nn.Module):
expansion = 4
def __init__(
self, inplanes, planes, stride=1, downsample=None, cardinality=1, base_width=64, sk_kwargs=None,
reduce_first=1, dilation=1, first_dilation=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
self,
inplanes,
planes,
stride=1,
downsample=None,
cardinality=1,
base_width=64,
sk_kwargs=None,
reduce_first=1,
dilation=1,
first_dilation=None,
act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d,
attn_layer=None,
aa_layer=None,
drop_block=None,
drop_path=None,
):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
@ -115,7 +146,8 @@ class SelectiveKernelBottleneck(nn.Module):
self.drop_path = drop_path
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
if getattr(self.conv3.bn, 'weight', None) is not None:
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
shortcut = x

File diff suppressed because it is too large Load Diff

@ -27,72 +27,6 @@ from .resnetv2 import ResNetV2, create_resnetv2_stem
from .vision_transformer import _create_vision_transformer
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
custom_load=True,
),
'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224.untrained': _cfg(),
'vit_base_r50_s16_384.v1_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
custom_load=True,
),
'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224.v1_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
})
class HybridEmbed(nn.Module):
""" CNN Feature Map Embedding
Extract feature map from CNN, flatten, project to embedding dim.
@ -166,6 +100,83 @@ def _resnetv2(layers=(3, 4, 9), **kwargs):
return backbone
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
# hybrid in-1k models (weights from official JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
first_conv='patch_embed.backbone.conv'),
'vit_tiny_r_s16_p8_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
first_conv='patch_embed.backbone.conv', input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_small_r26_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_light0-wd_0.03-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.03-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
),
'vit_small_r26_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.03-res_384.npz',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True),
'vit_base_r26_s32_224.untrained': _cfg(),
'vit_base_r50_s16_384.orig_in21k_ft_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_384-9fd3c705.pth',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0),
'vit_large_r50_s32_224.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium1-wd_0.1-do_0.1-sd_0.1--imagenet2012-steps_20k-lr_0.01-res_224.npz',
hf_hub_id='timm/',
custom_load=True,
),
'vit_large_r50_s32_384.augreg_in21k_ft_in1k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0--imagenet2012-steps_20k-lr_0.01-res_384.npz',
hf_hub_id='timm/',
input_size=(3, 384, 384), crop_pct=1.0, custom_load=True,
),
# hybrid in-21k models (weights from official Google JAX impl where they exist)
'vit_tiny_r_s16_p8_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R_Ti_16-i21k-300ep-lr_0.001-aug_none-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, first_conv='patch_embed.backbone.conv', custom_load=True),
'vit_small_r26_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R26_S_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.03-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, custom_load=True),
'vit_base_r50_s16_224.orig_in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_resnet50_224_in21k-6f7c7740.pth',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9),
'vit_large_r50_s32_224.augreg_in21k': _cfg(
url='https://storage.googleapis.com/vit_models/augreg/R50_L_32-i21k-300ep-lr_0.001-aug_medium2-wd_0.1-do_0.0-sd_0.0.npz',
hf_hub_id='timm/',
num_classes=21843, crop_pct=0.9, custom_load=True),
# hybrid models (using timm resnet backbones)
'vit_small_resnet26d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_small_resnet50d_s16_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet26d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
'vit_base_resnet50d_224.untrained': _cfg(
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, first_conv='patch_embed.backbone.conv1.0'),
})
@register_model
def vit_tiny_r_s16_p8_224(pretrained=False, **kwargs):
""" R+ViT-Ti/S16 w/ 8x8 patch hybrid @ 224 x 224.

@ -11,12 +11,12 @@ from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from timm.layers import PatchEmbed, Mlp, DropPath, trunc_normal_
from timm.layers import PatchEmbed, Mlp, DropPath, RelPosMlp, RelPosBias
from ._builder import build_model_with_cfg
from ._pretrained import generate_default_cfgs
from ._registry import register_model
__all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint fn to this
@ -24,216 +24,6 @@ __all__ = ['VisionTransformerRelPos'] # model_registry will add each entrypoint
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'vit_relpos_base_patch32_plus_rpn_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
'vit_relpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
'vit_srelpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth'),
'vit_srelpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth'),
'vit_relpos_medium_patch16_cls_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth'),
'vit_relpos_base_patch16_cls_224': _cfg(
url=''),
'vit_relpos_base_patch16_clsgap_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
'vit_relpos_medium_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
}
def gen_relative_position_index(
q_size: Tuple[int, int],
k_size: Tuple[int, int] = None,
class_token: bool = False) -> torch.Tensor:
# Adapted with significant modifications from Swin / BeiT codebases
# get pair-wise relative position index for each token inside the window
q_coords = torch.stack(torch.meshgrid([torch.arange(q_size[0]), torch.arange(q_size[1])])).flatten(1) # 2, Wh, Ww
if k_size is None:
k_coords = q_coords
k_size = q_size
else:
# different q vs k sizes is a WIP
k_coords = torch.stack(torch.meshgrid([torch.arange(k_size[0]), torch.arange(k_size[1])])).flatten(1)
relative_coords = q_coords[:, :, None] - k_coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0) # Wh*Ww, Wh*Ww, 2
_, relative_position_index = torch.unique(relative_coords.view(-1, 2), return_inverse=True, dim=0)
if class_token:
# handle cls to token & token 2 cls & cls to cls as per beit for rel pos bias
# NOTE not intended or tested with MLP log-coords
max_size = (max(q_size[0], k_size[0]), max(q_size[1], k_size[1]))
num_relative_distance = (2 * max_size[0] - 1) * (2 * max_size[1] - 1) + 3
relative_position_index = F.pad(relative_position_index, [1, 0, 1, 0])
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
return relative_position_index.contiguous()
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin',
):
assert mode in ('swin', 'cr', 'rw')
# as per official swin-v2 impl, supporting timm specific 'cr' and 'rw' log coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / math.log2(8)
else:
if mode == 'rw':
# cr w/ window size normalization -> [-1,1] log coords
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # scale to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs())
relative_coords_table /= math.log2(9) # -> [-1, 1]
else:
# mode == 'cr'
relative_coords_table = torch.sign(relative_coords_table) * torch.log(
1.0 + relative_coords_table.abs())
return relative_coords_table
class RelPosMlp(nn.Module):
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
prefix_tokens=0,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.prefix_tokens = prefix_tokens
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
if mode == 'swin':
self.bias_act = nn.Sigmoid()
self.bias_gain = 16
mlp_bias = (True, False)
elif mode == 'rw':
self.bias_act = nn.Tanh()
self.bias_gain = 4
mlp_bias = True
else:
self.bias_act = nn.Identity()
self.bias_gain = None
mlp_bias = True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
relative_position_bias = self.bias_act(relative_position_bias)
if self.bias_gain is not None:
relative_position_bias = self.bias_gain * relative_position_bias
if self.prefix_tokens:
relative_position_bias = F.pad(relative_position_bias, [self.prefix_tokens, 0, self.prefix_tokens, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, prefix_tokens=0):
super().__init__()
assert prefix_tokens <= 1
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.bias_shape = (self.window_area + prefix_tokens,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * prefix_tokens
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=prefix_tokens > 0),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
super().__init__()
@ -513,6 +303,57 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
return model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = generate_default_cfgs({
'vit_relpos_base_patch32_plus_rpn_256.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
hf_hub_id='timm/',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240.untrained': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_small_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth',
hf_hub_id='timm/'),
'vit_relpos_medium_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth',
hf_hub_id='timm/'),
'vit_srelpos_small_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_small_patch16_224-sw-6cdb8849.pth',
hf_hub_id='timm/'),
'vit_srelpos_medium_patch16_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_srelpos_medium_patch16_224-sw-ad702b8c.pth',
hf_hub_id='timm/'),
'vit_relpos_medium_patch16_cls_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_cls_224-sw-cfe8e259.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_cls_224.untrained': _cfg(),
'vit_relpos_base_patch16_clsgap_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth',
hf_hub_id='timm/'),
'vit_relpos_small_patch16_rpn_224.untrained': _cfg(),
'vit_relpos_medium_patch16_rpn_224.sw_in1k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth',
hf_hub_id='timm/'),
'vit_relpos_base_patch16_rpn_224.untrained': _cfg(),
})
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token

@ -181,8 +181,18 @@ class SequentialAppendList(nn.Sequential):
class OsaBlock(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, layer_per_block, residual=False,
depthwise=False, attn='', norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path=None):
self,
in_chs,
mid_chs,
out_chs,
layer_per_block,
residual=False,
depthwise=False,
attn='',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path=None,
):
super(OsaBlock, self).__init__()
self.residual = residual
@ -232,9 +242,20 @@ class OsaBlock(nn.Module):
class OsaStage(nn.Module):
def __init__(
self, in_chs, mid_chs, out_chs, block_per_stage, layer_per_block, downsample=True,
residual=True, depthwise=False, attn='ese', norm_layer=BatchNormAct2d, act_layer=nn.ReLU,
drop_path_rates=None):
self,
in_chs,
mid_chs,
out_chs,
block_per_stage,
layer_per_block,
downsample=True,
residual=True,
depthwise=False,
attn='ese',
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_path_rates=None,
):
super(OsaStage, self).__init__()
self.grad_checkpointing = False
@ -270,16 +291,38 @@ class OsaStage(nn.Module):
class VovNet(nn.Module):
def __init__(
self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_rate=0., stem_stride=4,
output_stride=32, norm_layer=BatchNormAct2d, act_layer=nn.ReLU, drop_path_rate=0.):
""" VovNet (v2)
self,
cfg,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
norm_layer=BatchNormAct2d,
act_layer=nn.ReLU,
drop_rate=0.,
drop_path_rate=0.,
**kwargs,
):
"""
Args:
cfg (dict): Model architecture configuration
in_chans (int): Number of input channels (default: 3)
num_classes (int): Number of classifier classes (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
norm_layer (Union[str, nn.Module]): normalization layer
act_layer (Union[str, nn.Module]): activation layer
drop_rate (float): Dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth drop-path rate (default: 0.)
kwargs (dict): Extra kwargs overlayed onto cfg
"""
super(VovNet, self).__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
assert stem_stride in (4, 2)
assert output_stride == 32 # FIXME support dilation
cfg = dict(cfg, **kwargs)
stem_stride = cfg.get("stem_stride", 4)
stem_chs = cfg["stem_chs"]
stage_conv_chs = cfg["stage_conv_chs"]
stage_out_chs = cfg["stage_out_chs"]
@ -307,9 +350,15 @@ class VovNet(nn.Module):
for i in range(4): # num_stages
downsample = stem_stride == 2 or i > 0 # first stage has no stride/downsample if stem_stride is 4
stages += [OsaStage(
in_ch_list[i], stage_conv_chs[i], stage_out_chs[i], block_per_stage[i], layer_per_block,
downsample=downsample, drop_path_rates=stage_dpr[i], **stage_args)
]
in_ch_list[i],
stage_conv_chs[i],
stage_out_chs[i],
block_per_stage[i],
layer_per_block,
downsample=downsample,
drop_path_rates=stage_dpr[i],
**stage_args,
)]
self.num_features = stage_out_chs[i]
current_stride *= 2 if downsample else 1
self.feature_info += [dict(num_chs=self.num_features, reduction=current_stride, module=f'stages.{i}')]
@ -324,7 +373,6 @@ class VovNet(nn.Module):
elif isinstance(m, nn.Linear):
nn.init.zeros_(m.bias)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(

@ -8,7 +8,7 @@ from .distributed import distribute_bn, reduce_tensor, init_distributed_device,\
from .jit import set_jit_legacy, set_jit_fuser
from .log import setup_default_logging, FormatterNoInfo
from .metrics import AverageMeter, accuracy
from .misc import natural_key, add_bool_arg
from .misc import natural_key, add_bool_arg, ParseKwargs
from .model import unwrap_model, get_state_dict, freeze, unfreeze
from .model_ema import ModelEma, ModelEmaV2
from .random import random_seed

@ -2,6 +2,8 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import argparse
import ast
import re
@ -16,3 +18,15 @@ def add_bool_arg(parser, name, default=False, help=''):
group.add_argument('--' + name, dest=dest_name, action='store_true', help=help)
group.add_argument('--no-' + name, dest=dest_name, action='store_false', help=help)
parser.set_defaults(**{dest_name: default})
class ParseKwargs(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
kw = {}
for value in values:
key, value = value.split('=')
try:
kw[key] = ast.literal_eval(value)
except ValueError:
kw[key] = str(value) # fallback to string (avoid need to escape on command line)
setattr(namespace, self.dest, kw)

@ -7,6 +7,8 @@ import fnmatch
import torch
from torchvision.ops.misc import FrozenBatchNorm2d
from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\
freeze_batch_norm_2d, unfreeze_batch_norm_2d
from .model_ema import ModelEma
@ -100,70 +102,6 @@ def extract_spp_stats(
return hook.stats
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
returned. Otherwise, the module is walked recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
"""
Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
"""
assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'
if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(root_module, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
# Raise assertion here because we can't convert it in place
raise AssertionError(
"You have provided a batch norm layer as the `root module`. Please use "
@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True,
# It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't
# convert it in place, but will return the converted result. In this case `res` holds the converted
# result and we may try to re-assign the named module
if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
if isinstance(m, (
torch.nn.modules.batchnorm.BatchNorm2d,
torch.nn.modules.batchnorm.SyncBatchNorm,
BatchNormAct2d,
SyncBatchNormAct,
)):
_add_submodule(root_module, n, res)
# Unfreeze batch norm
else:
res = unfreeze_batch_norm_2d(m)
# Ditto. See note above in mode == 'freeze' branch
if isinstance(m, FrozenBatchNorm2d):
if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)):
_add_submodule(root_module, n, res)

@ -1 +1 @@
__version__ = '0.8.1dev0'
__version__ = '0.8.6dev0'

@ -89,56 +89,58 @@ parser.add_argument('--data-dir', metavar='DIR',
parser.add_argument('--dataset', metavar='NAME', default='',
help='dataset type + name ("<type>/<name>") (default: ImageFolder or ImageTar if empty)')
group.add_argument('--train-split', metavar='NAME', default='train',
help='dataset train split (default: train)')
help='dataset train split (default: train)')
group.add_argument('--val-split', metavar='NAME', default='validation',
help='dataset validation split (default: validation)')
help='dataset validation split (default: validation)')
group.add_argument('--dataset-download', action='store_true', default=False,
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
help='Allow download of dataset for torch/ and tfds/ datasets that support it.')
group.add_argument('--class-map', default='', type=str, metavar='FILENAME',
help='path to class to idx mapping file (default: "")')
help='path to class to idx mapping file (default: "")')
# Model parameters
group = parser.add_argument_group('Model parameters')
group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
help='Name of model to train (default: "resnet50")')
help='Name of model to train (default: "resnet50")')
group.add_argument('--pretrained', action='store_true', default=False,
help='Start with pretrained version of specified network (if avail)')
help='Start with pretrained version of specified network (if avail)')
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
help='Initialize model from this checkpoint (default: none)')
help='Initialize model from this checkpoint (default: none)')
group.add_argument('--resume', default='', type=str, metavar='PATH',
help='Resume full model and optimizer state from checkpoint (default: none)')
help='Resume full model and optimizer state from checkpoint (default: none)')
group.add_argument('--no-resume-opt', action='store_true', default=False,
help='prevent resume of optimizer state when resuming model')
help='prevent resume of optimizer state when resuming model')
group.add_argument('--num-classes', type=int, default=None, metavar='N',
help='number of label classes (Model default if None)')
help='number of label classes (Model default if None)')
group.add_argument('--gp', default=None, type=str, metavar='POOL',
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
help='Global pool type, one of (fast, avg, max, avgmax, avgmaxc). Model default if None.')
group.add_argument('--img-size', type=int, default=None, metavar='N',
help='Image size (default: None => model default)')
help='Image size (default: None => model default)')
group.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
help='Image input channels (default: None => 3)')
group.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
metavar='N N N',
help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
group.add_argument('--crop-pct', default=None, type=float,
metavar='N', help='Input image center crop percent (for validation only)')
metavar='N', help='Input image center crop percent (for validation only)')
group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
help='Override mean pixel value of dataset')
help='Override mean pixel value of dataset')
group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
help='Override std deviation of dataset')
help='Override std deviation of dataset')
group.add_argument('--interpolation', default='', type=str, metavar='NAME',
help='Image resize interpolation type (overrides model)')
help='Image resize interpolation type (overrides model)')
group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N',
help='Input batch size for training (default: 128)')
help='Input batch size for training (default: 128)')
group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N',
help='Validation batch size override (default: None)')
help='Validation batch size override (default: None)')
group.add_argument('--channels-last', action='store_true', default=False,
help='Use channels_last memory layout')
help='Use channels_last memory layout')
group.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
group.add_argument('--grad-checkpointing', action='store_true', default=False,
help='Enable gradient checkpointing through model blocks/stages')
help='Enable gradient checkpointing through model blocks/stages')
group.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
help='enable experimental fast-norm')
group.add_argument('--model-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
scripting_group = group.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
@ -151,199 +153,200 @@ scripting_group.add_argument('--aot-autograd', default=False, action='store_true
# Optimizer parameters
group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "sgd")')
help='Optimizer (default: "sgd")')
group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: None, use opt default)')
help='Optimizer Epsilon (default: None, use opt default)')
group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
help='Optimizer Betas (default: None, use opt default)')
group.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='Optimizer momentum (default: 0.9)')
help='Optimizer momentum (default: 0.9)')
group.add_argument('--weight-decay', type=float, default=2e-5,
help='weight decay (default: 2e-5)')
help='weight decay (default: 2e-5)')
group.add_argument('--clip-grad', type=float, default=None, metavar='NORM',
help='Clip gradient norm (default: None, no clipping)')
help='Clip gradient norm (default: None, no clipping)')
group.add_argument('--clip-mode', type=str, default='norm',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
help='Gradient clipping mode. One of ("norm", "value", "agc")')
group.add_argument('--layer-decay', type=float, default=None,
help='layer-wise learning rate decay (default: None)')
help='layer-wise learning rate decay (default: None)')
group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
# Learning rate schedule parameters
group = parser.add_argument_group('Learning rate schedule parameters')
group.add_argument('--sched', type=str, default='cosine', metavar='SCHEDULER',
help='LR scheduler (default: "step"')
help='LR scheduler (default: "step"')
group.add_argument('--sched-on-updates', action='store_true', default=False,
help='Apply LR scheduler step on update instead of epoch end.')
help='Apply LR scheduler step on update instead of epoch end.')
group.add_argument('--lr', type=float, default=None, metavar='LR',
help='learning rate, overrides lr-base if set (default: None)')
help='learning rate, overrides lr-base if set (default: None)')
group.add_argument('--lr-base', type=float, default=0.1, metavar='LR',
help='base learning rate: lr = lr_base * global_batch_size / base_size')
help='base learning rate: lr = lr_base * global_batch_size / base_size')
group.add_argument('--lr-base-size', type=int, default=256, metavar='DIV',
help='base learning rate batch size (divisor, default: 256).')
help='base learning rate batch size (divisor, default: 256).')
group.add_argument('--lr-base-scale', type=str, default='', metavar='SCALE',
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
help='base learning rate vs batch_size scaling ("linear", "sqrt", based on opt if empty)')
group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
help='learning rate noise on/off epoch percentages')
group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
help='learning rate noise limit percent (default: 0.67)')
group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
help='learning rate noise std-dev (default: 1.0)')
group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT',
help='learning rate cycle len multiplier (default: 1.0)')
help='learning rate cycle len multiplier (default: 1.0)')
group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT',
help='amount to decay each learning rate cycle (default: 0.5)')
help='amount to decay each learning rate cycle (default: 0.5)')
group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N',
help='learning rate cycle limit, cycles enabled if > 1')
help='learning rate cycle limit, cycles enabled if > 1')
group.add_argument('--lr-k-decay', type=float, default=1.0,
help='learning rate k-decay for cosine/poly (default: 1.0)')
help='learning rate k-decay for cosine/poly (default: 1.0)')
group.add_argument('--warmup-lr', type=float, default=1e-5, metavar='LR',
help='warmup learning rate (default: 1e-5)')
help='warmup learning rate (default: 1e-5)')
group.add_argument('--min-lr', type=float, default=0, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
help='lower lr bound for cyclic schedulers that hit 0 (default: 0)')
group.add_argument('--epochs', type=int, default=300, metavar='N',
help='number of epochs to train (default: 300)')
help='number of epochs to train (default: 300)')
group.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
group.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
help='manual epoch number (useful on restarts)')
group.add_argument('--decay-milestones', default=[90, 180, 270], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
help='list of decay epoch indices for multistep lr. must be increasing')
group.add_argument('--decay-epochs', type=float, default=90, metavar='N',
help='epoch interval to decay LR')
help='epoch interval to decay LR')
group.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
help='epochs to warmup LR, if scheduler supports')
group.add_argument('--warmup-prefix', action='store_true', default=False,
help='Exclude warmup period from decay schedule.'),
help='Exclude warmup period from decay schedule.'),
group.add_argument('--cooldown-epochs', type=int, default=0, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
group.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
help='patience epochs for Plateau LR scheduler (default: 10)')
group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
help='LR decay rate (default: 0.1)')
# Augmentation & regularization parameters
group = parser.add_argument_group('Augmentation and regularization parameters')
group.add_argument('--no-aug', action='store_true', default=False,
help='Disable all training augmentation, override other train aug args')
help='Disable all training augmentation, override other train aug args')
group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT',
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
help='Random resize scale (default: 0.08 1.0)')
group.add_argument('--ratio', type=float, nargs='+', default=[3. / 4., 4. / 3.], metavar='RATIO',
help='Random resize aspect ratio (default: 0.75 1.33)')
group.add_argument('--hflip', type=float, default=0.5,
help='Horizontal flip training aug probability')
help='Horizontal flip training aug probability')
group.add_argument('--vflip', type=float, default=0.,
help='Vertical flip training aug probability')
help='Vertical flip training aug probability')
group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
help='Color jitter factor (default: 0.4)')
group.add_argument('--aa', type=str, default=None, metavar='NAME',
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
help='Use AutoAugment policy. "v0" or "original". (default: None)'),
group.add_argument('--aug-repeats', type=float, default=0,
help='Number of augmentation repetitions (distributed training only) (default: 0)')
help='Number of augmentation repetitions (distributed training only) (default: 0)')
group.add_argument('--aug-splits', type=int, default=0,
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
help='Number of augmentation splits (default: 0, valid: 0 or >=2)')
group.add_argument('--jsd-loss', action='store_true', default=False,
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.')
group.add_argument('--bce-loss', action='store_true', default=False,
help='Enable BCE loss w/ Mixup/CutMix use.')
help='Enable BCE loss w/ Mixup/CutMix use.')
group.add_argument('--bce-target-thresh', type=float, default=None,
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
help='Threshold for binarizing softened BCE targets (default: None, disabled)')
group.add_argument('--reprob', type=float, default=0., metavar='PCT',
help='Random erase prob (default: 0.)')
help='Random erase prob (default: 0.)')
group.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
help='Random erase mode (default: "pixel")')
group.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
help='Random erase count (default: 1)')
group.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
help='Do not random erase first (clean) augmentation split')
group.add_argument('--mixup', type=float, default=0.0,
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
help='mixup alpha, mixup enabled if > 0. (default: 0.)')
group.add_argument('--cutmix', type=float, default=0.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
help='cutmix alpha, cutmix enabled if > 0. (default: 0.)')
group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
group.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
help='Probability of performing mixup or cutmix when either/both is enabled')
group.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
help='Probability of switching to cutmix when both mixup and cutmix enabled')
group.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N',
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
help='Turn off mixup after this epoch, disabled if 0 (default: 0)')
group.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
help='Label smoothing (default: 0.1)')
group.add_argument('--train-interpolation', type=str, default='random',
help='Training interpolation (random, bilinear, bicubic default: "random")')
help='Training interpolation (random, bilinear, bicubic default: "random")')
group.add_argument('--drop', type=float, default=0.0, metavar='PCT',
help='Dropout rate (default: 0.)')
help='Dropout rate (default: 0.)')
group.add_argument('--drop-connect', type=float, default=None, metavar='PCT',
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
help='Drop connect rate, DEPRECATED, use drop-path (default: None)')
group.add_argument('--drop-path', type=float, default=None, metavar='PCT',
help='Drop path rate (default: None)')
help='Drop path rate (default: None)')
group.add_argument('--drop-block', type=float, default=None, metavar='PCT',
help='Drop block rate (default: None)')
help='Drop block rate (default: None)')
# Batch norm parameters (only works with gen_efficientnet based models currently)
group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.')
group.add_argument('--bn-momentum', type=float, default=None,
help='BatchNorm momentum override (if not None)')
help='BatchNorm momentum override (if not None)')
group.add_argument('--bn-eps', type=float, default=None,
help='BatchNorm epsilon override (if not None)')
help='BatchNorm epsilon override (if not None)')
group.add_argument('--sync-bn', action='store_true',
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
group.add_argument('--dist-bn', type=str, default='reduce',
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
group.add_argument('--split-bn', action='store_true',
help='Enable separate BN layers per augmentation split.')
help='Enable separate BN layers per augmentation split.')
# Model Exponential Moving Average
group = parser.add_argument_group('Model exponential moving average parameters')
group.add_argument('--model-ema', action='store_true', default=False,
help='Enable tracking moving average of model weights')
help='Enable tracking moving average of model weights')
group.add_argument('--model-ema-force-cpu', action='store_true', default=False,
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.')
group.add_argument('--model-ema-decay', type=float, default=0.9998,
help='decay factor for model weights moving average (default: 0.9998)')
help='decay factor for model weights moving average (default: 0.9998)')
# Misc
group = parser.add_argument_group('Miscellaneous parameters')
group.add_argument('--seed', type=int, default=42, metavar='S',
help='random seed (default: 42)')
help='random seed (default: 42)')
group.add_argument('--worker-seeding', type=str, default='all',
help='worker seed mode (default: all)')
help='worker seed mode (default: all)')
group.add_argument('--log-interval', type=int, default=50, metavar='N',
help='how many batches to wait before logging training status')
help='how many batches to wait before logging training status')
group.add_argument('--recovery-interval', type=int, default=0, metavar='N',
help='how many batches to wait before writing recovery checkpoint')
help='how many batches to wait before writing recovery checkpoint')
group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N',
help='number of checkpoints to keep (default: 10)')
help='number of checkpoints to keep (default: 10)')
group.add_argument('-j', '--workers', type=int, default=4, metavar='N',
help='how many training processes to use (default: 4)')
help='how many training processes to use (default: 4)')
group.add_argument('--save-images', action='store_true', default=False,
help='save images of input bathes every log interval for debugging')
help='save images of input bathes every log interval for debugging')
group.add_argument('--amp', action='store_true', default=False,
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
help='use NVIDIA Apex AMP or Native AMP for mixed precision training')
group.add_argument('--amp-dtype', default='float16', type=str,
help='lower precision AMP dtype (default: float16)')
help='lower precision AMP dtype (default: float16)')
group.add_argument('--amp-impl', default='native', type=str,
help='AMP impl to use, "native" or "apex" (default: native)')
help='AMP impl to use, "native" or "apex" (default: native)')
group.add_argument('--no-ddp-bb', action='store_true', default=False,
help='Force broadcast buffers for native DDP to off.')
help='Force broadcast buffers for native DDP to off.')
group.add_argument('--pin-mem', action='store_true', default=False,
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
group.add_argument('--no-prefetcher', action='store_true', default=False,
help='disable fast prefetcher')
help='disable fast prefetcher')
group.add_argument('--output', default='', type=str, metavar='PATH',
help='path to output folder (default: none, current dir)')
help='path to output folder (default: none, current dir)')
group.add_argument('--experiment', default='', type=str, metavar='NAME',
help='name of train experiment, name of sub-folder for output')
help='name of train experiment, name of sub-folder for output')
group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC',
help='Best metric (default: "top1"')
help='Best metric (default: "top1"')
group.add_argument('--tta', type=int, default=0, metavar='N',
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)')
group.add_argument("--local_rank", default=0, type=int)
group.add_argument('--use-multi-epochs-loader', action='store_true', default=False,
help='use the multi-epochs-loader to save time at the beginning of every epoch')
help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
help='log training and validation metrics to wandb')
help='log training and validation metrics to wandb')
def _parse_args():
@ -371,8 +374,6 @@ def main():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
if args.data and not args.data_dir:
args.data_dir = args.data
args.prefetcher = not args.no_prefetcher
device = utils.init_distributed_device(args)
if args.distributed:
@ -383,14 +384,6 @@ def main():
_logger.info(f'Training with a single process on 1 device ({args.device}).')
assert args.rank >= 0
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# resolve AMP arguments based on PyTorch / Apex availability
use_amp = None
amp_dtype = torch.float16
@ -432,6 +425,7 @@ def main():
bn_eps=args.bn_eps,
scriptable=args.torchscript,
checkpoint_path=args.initial_checkpoint,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -504,7 +498,11 @@ def main():
f'Learning rate ({args.lr}) calculated from base learning rate ({args.lr_base}) '
f'and global batch size ({global_batch_size}) with {args.lr_base_scale} scaling.')
optimizer = create_optimizer_v2(model, **optimizer_kwargs(cfg=args))
optimizer = create_optimizer_v2(
model,
**optimizer_kwargs(cfg=args),
**args.opt_kwargs,
)
# setup automatic mixed-precision (AMP) loss scaling and op casting
amp_autocast = suppress # do nothing
@ -559,6 +557,8 @@ def main():
# NOTE: EMA model does not need to be wrapped by DDP
# create the train and eval datasets
if args.data and not args.data_dir:
args.data_dir = args.data
dataset_train = create_dataset(
args.dataset,
root=args.data_dir,
@ -712,6 +712,14 @@ def main():
with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
f.write(args_text)
if utils.is_primary(args) and args.log_wandb:
if has_wandb:
wandb.init(project=args.experiment, config=args)
else:
_logger.warning(
"You've requested to log metrics to wandb but package not found. "
"Metrics not being logged to wandb, try `pip install wandb`")
# setup learning rate schedule and starting epoch
updates_per_epoch = len(loader_train)
lr_scheduler, num_epochs = create_scheduler_v2(

@ -26,7 +26,7 @@ from timm.data import create_dataset, create_loader, resolve_data_config, RealLa
from timm.layers import apply_test_time_pool, set_fast_norm
from timm.models import create_model, load_checkpoint, is_model, list_models
from timm.utils import accuracy, AverageMeter, natural_key, setup_default_logging, set_jit_fuser, \
decay_batch_step, check_batch_size_retry
decay_batch_step, check_batch_size_retry, ParseKwargs
try:
from apex import amp
@ -71,6 +71,8 @@ parser.add_argument('-b', '--batch-size', default=256, type=int,
metavar='N', help='mini-batch size (default: 256)')
parser.add_argument('--img-size', default=None, type=int,
metavar='N', help='Input image dimension, uses model default if empty')
parser.add_argument('--in-chans', type=int, default=None, metavar='N',
help='Image input channels (default: None => 3)')
parser.add_argument('--input-size', default=None, nargs=3, type=int,
metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty')
parser.add_argument('--use-train-size', action='store_true', default=False,
@ -123,6 +125,8 @@ parser.add_argument('--fuser', default='', type=str,
help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
parser.add_argument('--fast-norm', default=False, action='store_true',
help='enable experimental fast-norm')
parser.add_argument('--model-kwargs', nargs='*', default={}, action=ParseKwargs)
scripting_group = parser.add_mutually_exclusive_group()
scripting_group.add_argument('--torchscript', default=False, action='store_true',
@ -181,13 +185,20 @@ def validate(args):
set_fast_norm()
# create model
in_chans = 3
if args.in_chans is not None:
in_chans = args.in_chans
elif args.input_size is not None:
in_chans = args.input_size[0]
model = create_model(
args.model,
pretrained=args.pretrained,
num_classes=args.num_classes,
in_chans=3,
in_chans=in_chans,
global_pool=args.gp,
scriptable=args.torchscript,
**args.model_kwargs,
)
if args.num_classes is None:
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
@ -232,8 +243,9 @@ def validate(args):
criterion = nn.CrossEntropyLoss().to(device)
root_dir = args.data or args.data_dir
dataset = create_dataset(
root=args.data,
root=root_dir,
name=args.dataset,
split=args.split,
download=args.dataset_download,
@ -389,7 +401,7 @@ def main():
if args.model == 'all':
# validate all models in a list of names with pretrained checkpoints
args.pretrained = True
model_names = list_models(pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*_dino'])
model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
model_cfgs = [(n, '') for n in model_names]
elif not is_model(args.model):
# model name doesn't exist, try as wildcard filter

Loading…
Cancel
Save