Merge remote-tracking branch 'origin/master' into bits_and_tpu

pull/1414/head
Ross Wightman 3 years ago
commit dff33730b3

@ -23,6 +23,32 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New ## What's New
### May 13, 2022
* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript.
* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects.
* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program)
* `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool
* `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake)
* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020)
* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials
* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg)
### May 2, 2022
* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`)
* `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool
* `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool
* `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool
* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`)
* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae).
### April 22, 2022
* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/).
* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress.
* `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288
* `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288
### March 23, 2022 ### March 23, 2022
* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) * Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795)
* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. * `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs.
@ -375,6 +401,7 @@ A full version of the list below with source links can be found in the [document
* ReXNet - https://arxiv.org/abs/2007.00992 * ReXNet - https://arxiv.org/abs/2007.00992
* SelecSLS - https://arxiv.org/abs/1907.00837 * SelecSLS - https://arxiv.org/abs/1907.00837
* Selective Kernel Networks - https://arxiv.org/abs/1903.06586 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586
* Sequencer2D - https://arxiv.org/abs/2205.01972
* Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725 * Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725
* Swin Transformer - https://arxiv.org/abs/2103.14030 * Swin Transformer - https://arxiv.org/abs/2103.14030
* Swin Transformer V2 - https://arxiv.org/abs/2111.09883 * Swin Transformer V2 - https://arxiv.org/abs/2111.09883
@ -462,7 +489,7 @@ My current [documentation](https://rwightman.github.io/pytorch-image-models/) fo
[Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail.
[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs.
[paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`.

@ -1,5 +1,134 @@
# Archived Changes # Archived Changes
### June 8, 2021
* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
* NFNet inspired block layout with quad layer stem and no maxpool
* Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
### May 25, 2021
* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Cleanup input_size/img_size override handling and testing for all vision transformer models
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training
### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modles
* Improve SA module inits
* Hack together experimental stand-alone Swin based attn module and `swinnet`
* Consistent '26t' model defs for experiments.
* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1.
* WandB logging support
### April 13, 2021
* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer
### April 12, 2021
* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256.
* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training.
* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs
* Lambda Networks - https://arxiv.org/abs/2102.08602
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* Halo Nets - https://arxiv.org/abs/2103.12731
* Adabelief optimizer contributed by Juntang Zhuang
### April 1, 2021
* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference
* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit)
* Merged distilled variant into main for torchscript compatibility
* Some `timm` cleanup/style tweaks and weights have hub download support
* Cleanup Vision Transformer (ViT) models
* Merge distilled (DeiT) model into main so that torchscript can work
* Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch)
* Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids
* Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants
* nn.Sequential for block stack (does not break downstream compat)
* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT)
* Add RegNetY-160 weights from DeiT teacher model
* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288
* Some fixes/improvements for TFDS dataset wrapper
### March 7, 2021
* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc).
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation.
### Feb 18, 2021
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets).
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn.
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants.
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated).
* Matching the original pre-processing as closely as possible I get these results:
* `dm_nfnet_f6` - 86.352
* `dm_nfnet_f5` - 86.100
* `dm_nfnet_f4` - 85.834
* `dm_nfnet_f3` - 85.676
* `dm_nfnet_f2` - 85.178
* `dm_nfnet_f1` - 84.696
* `dm_nfnet_f0` - 83.464
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py.
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc`
* PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0`
* PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value`
* AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet.
### Feb 12, 2021
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs
### Feb 10, 2021
* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks')
* GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
* RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py`
* classic VGG (from torchvision, impl in `vgg`)
* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models
* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not.
* Fix a few bugs introduced since last pypi release
### Feb 8, 2021
* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352.
* `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
* `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256
* `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320
* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test.
### Jan 30, 2021
* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692)
### Jan 25, 2021
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support
* NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script
* Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2`
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar
* Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp`
* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling
### Jan 3, 2021
* Add SE-ResNet-152D weights
* 256x256 val, 0.94 crop top-1 - 83.75
* 320x320 val, 1.0 crop - 84.36
* Update results files
### Dec 18, 2020 ### Dec 18, 2020
* Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256 * Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256
* 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25) * 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25)

@ -1,130 +1,130 @@
# Recent Changes # Recent Changes
### June 8, 2021
* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1.
* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1.
* NFNet inspired block layout with quad layer stem and no maxpool
* Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288
### May 25, 2021 ### March 23, 2022
* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models * Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795)
* Cleanup input_size/img_size override handling and testing for all vision transformer models * `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs.
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### March 21, 2022
### May 14, 2021 * Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required.
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. * Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights)
* 1k trained variants: `tf_efficientnetv2_s/m/l` * `regnety_040` - 82.3 @ 224, 82.96 @ 288
* 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` * `regnety_064` - 83.0 @ 224, 83.65 @ 288
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` * `regnety_080` - 83.17 @ 224, 83.86 @ 288
* v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` * `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act)
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` * `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act)
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training * `regnetz_040` - 83.67 @ 256, 84.25 @ 320
* `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head)
### May 5, 2021 * `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm)
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) * `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) * `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) * `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) * `xception41p` - 82 @ 299 (timm pre-act)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) * `xception65` - 83.17 @ 299
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) * `xception65p` - 83.14 @ 299 (timm pre-act)
* Update ByoaNet attention modles * `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288
* Improve SA module inits * `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288
* Hack together experimental stand-alone Swin based attn module and `swinnet` * `resnetrs200` - 83.85 @ 256, 84.44 @ 320
* Consistent '26t' model defs for experiments. * HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon)
* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. * SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks.
* WandB logging support * Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2
* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets
### April 13, 2021 * PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer
* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer * VOLO models w/ weights adapted from https://github.com/sail-sg/volo
* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc
### April 12, 2021 * Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception
* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256. * Grouped conv support added to EfficientNet family
* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training. * Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler
* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs * Gradient checkpointing support added to many models
* Lambda Networks - https://arxiv.org/abs/2102.08602 * `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head`
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head`
* Halo Nets - https://arxiv.org/abs/2103.12731
* Adabelief optimizer contributed by Juntang Zhuang ### Feb 2, 2022
* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioners Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055)
### April 1, 2021 * I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so.
* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs!
* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable.
* Merged distilled variant into main for torchscript compatibility
* Some `timm` cleanup/style tweaks and weights have hub download support ### Jan 14, 2022
* Cleanup Vision Transformer (ViT) models * Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon....
* Merge distilled (DeiT) model into main so that torchscript can work * Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features
* Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) * Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way...
* Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids * `mnasnet_small` - 65.6 top-1
* Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants * `mobilenetv2_050` - 65.9
* nn.Sequential for block stack (does not break downstream compat) * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1
* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) * `semnasnet_075` - 73
* Add RegNetY-160 weights from DeiT teacher model * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0
* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 * TinyNet models added by [rsomani95](https://github.com/rsomani95)
* Some fixes/improvements for TFDS dataset wrapper * LCNet added via MobileNetV3 architecture
### March 7, 2021 ### Nov 22, 2021
* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc). * A number of updated weights anew new model defs
* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation. * `eca_halonext26ts` - 79.5 @ 256
* `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288
### Feb 18, 2021 * `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth))
* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). * `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288
* Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. * `sebotnet33ts_256` (new) - 81.2 @ 224
* These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. * `lamhalobotnet50ts_256` - 81.5 @ 256
* Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). * `halonet50ts` - 81.7 @ 256
* Matching the original pre-processing as closely as possible I get these results: * `halo2botnet50ts_256` - 82.0 @ 256
* `dm_nfnet_f6` - 86.352 * `resnet101` - 82.0 @ 224, 82.8 @ 288
* `dm_nfnet_f5` - 86.100 * `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288
* `dm_nfnet_f4` - 85.834 * `resnet152` - 82.8 @ 224, 83.5 @ 288
* `dm_nfnet_f3` - 85.676 * `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320
* `dm_nfnet_f2` - 85.178 * `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320
* `dm_nfnet_f1` - 84.696 * `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris)
* `dm_nfnet_f0` - 83.464 * Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare)
* models updated for tracing compatibility (almost full support with some distlled transformer exceptions)
### Feb 16, 2021
* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. ### Oct 19, 2021
* AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` * ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights)
* PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` * BCE loss and Repeated Augmentation support for RSB paper
* PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` * 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
* AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. * Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl):
* Halo (https://arxiv.org/abs/2103.12731)
### Feb 12, 2021 * Bottleneck Transformer (https://arxiv.org/abs/2101.11605)
* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs * LambdaNetworks (https://arxiv.org/abs/2102.08602)
* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights)
### Feb 10, 2021 * ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added
* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') * freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare)
* GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py`
* RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` ### Aug 18, 2021
* classic VGG (from torchvision, impl in `vgg`) * Optimizer bonanza!
* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits))
* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not. * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA)
* Fix a few bugs introduced since last pypi release * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all!
* SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself).
### Feb 8, 2021 * EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights.
* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352. * Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested.
* `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256
* `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256 ### July 12, 2021
* `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320 * Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare)
* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed).
* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test. ### July 5-9, 2021
* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res)
### Jan 30, 2021 * top-1 82.34 @ 288x288 and 82.54 @ 320x320
* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692) * Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models.
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
### Jan 25, 2021 * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426
* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer
* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer ### June 23, 2021
* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support * Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
* NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning
* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit ### June 20, 2021
* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes * Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270)
* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg)
* Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2` * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights
* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work.
* Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp` * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1)
* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling * `vit_deit_*` renamed to just `deit_*`
* Remove my old small model, replace with DeiT compatible small w/ AugReg weights
### Jan 3, 2021 * Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params.
* Add SE-ResNet-152D weights * Add weights from official ResMLP release (https://github.com/facebookresearch/deit)
* 256x256 val, 0.94 crop top-1 - 83.75 * Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384.
* 320x320 val, 1.0 crop - 84.36 * Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237)
* Update results files * NFNets and ResNetV2-BiT models work w/ Pytorch XLA now
* weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered)
* eps values adjusted, will be slight differences but should be quite close
* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models
* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool
* Please report any regressions, this PR touched quite a few models.

@ -4,7 +4,7 @@
Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`. Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`.
For a more comprehensive set of docs (currently under development), please visit [timmdocs](https://fastai.github.io/timmdocs/) by [Aman Arora](https://github.com/amaarora). For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora).
## Install ## Install
@ -20,17 +20,17 @@ pip install git+https://github.com/rwightman/pytorch-image-models.git
``` ```
!!! info "Conda Environment" !!! info "Conda Environment"
All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x., 3.8.x., 3.9 All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10
Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment.
PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code. PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code.
I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda:
``` ```
conda create -n torch-env conda create -n torch-env
conda activate torch-env conda activate torch-env
conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
conda install pyyaml conda install pyyaml
``` ```

@ -12,7 +12,7 @@ To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process pe
`./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4`
NOTE: It is recommended to use PyTorch 1.7+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed.
## Validation / Inference Scripts ## Validation / Inference Scripts

@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
NON_STD_FILTERS = [ NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*',
'poolformer_*', 'volo_*'] 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*']
NUM_NON_STD = len(NON_STD_FILTERS) NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures # exclude models that cause specific test failures
@ -202,13 +202,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
pytest.skip("Fixed input size model > limit.") pytest.skip("Fixed input size model > limit.")
input_tensor = torch.randn((batch_size, *input_size)) input_tensor = torch.randn((batch_size, *input_size))
feat_dim = getattr(model, 'feature_dim', None)
outputs = model.forward_features(input_tensor) outputs = model.forward_features(input_tensor)
if isinstance(outputs, (tuple, list)): if isinstance(outputs, (tuple, list)):
# cannot currently verify multi-tensor output. # cannot currently verify multi-tensor output.
pass pass
else: else:
feat_dim = -1 if outputs.ndim == 3 else 1 if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features assert outputs.shape[feat_dim] == model.num_features
# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
@ -216,14 +218,16 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
outputs = model.forward(input_tensor) outputs = model.forward(input_tensor)
if isinstance(outputs, (tuple, list)): if isinstance(outputs, (tuple, list)):
outputs = outputs[0] outputs = outputs[0]
feat_dim = -1 if outputs.ndim == 3 else 1 if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
model = create_model(model_name, pretrained=False, num_classes=0).eval() model = create_model(model_name, pretrained=False, num_classes=0).eval()
outputs = model.forward(input_tensor) outputs = model.forward(input_tensor)
if isinstance(outputs, (tuple, list)): if isinstance(outputs, (tuple, list)):
outputs = outputs[0] outputs = outputs[0]
feat_dim = -1 if outputs.ndim == 3 else 1 if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features assert outputs.shape[feat_dim] == model.num_features
# check classifier name matches default_cfg # check classifier name matches default_cfg

@ -39,8 +39,10 @@ from .resnetv2 import *
from .rexnet import * from .rexnet import *
from .selecsls import * from .selecsls import *
from .senet import * from .senet import *
from .sequencer import *
from .sknet import * from .sknet import *
from .swin_transformer import * from .swin_transformer import *
from .swin_transformer_v2 import *
from .swin_transformer_v2_cr import * from .swin_transformer_v2_cr import *
from .tnt import * from .tnt import *
from .tresnet import * from .tresnet import *
@ -49,6 +51,7 @@ from .vgg import *
from .visformer import * from .visformer import *
from .vision_transformer import * from .vision_transformer import *
from .vision_transformer_hybrid import * from .vision_transformer_hybrid import *
from .vision_transformer_relpos import *
from .volo import * from .volo import *
from .vovnet import * from .vovnet import *
from .xception import * from .xception import *

@ -46,27 +46,27 @@ def _cfg(url='', **kwargs):
default_cfgs = { default_cfgs = {
'beit_base_patch16_224': _cfg( 'beit_base_patch16_224': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'),
'beit_base_patch16_384': _cfg( 'beit_base_patch16_384': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0, input_size=(3, 384, 384), crop_pct=1.0,
), ),
'beit_base_patch16_224_in22k': _cfg( 'beit_base_patch16_224_in22k': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth',
num_classes=21841, num_classes=21841,
), ),
'beit_large_patch16_224': _cfg( 'beit_large_patch16_224': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'),
'beit_large_patch16_384': _cfg( 'beit_large_patch16_384': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth',
input_size=(3, 384, 384), crop_pct=1.0, input_size=(3, 384, 384), crop_pct=1.0,
), ),
'beit_large_patch16_512': _cfg( 'beit_large_patch16_512': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth',
input_size=(3, 512, 512), crop_pct=1.0, input_size=(3, 512, 512), crop_pct=1.0,
), ),
'beit_large_patch16_224_in22k': _cfg( 'beit_large_patch16_224_in22k': _cfg(
url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth',
num_classes=21841, num_classes=21841,
), ),
} }

@ -421,14 +421,14 @@ def convnext_large(pretrained=False, **kwargs):
@register_model @register_model
def convnext_tiny_in22ft1k(pretrained=False, **kwargs): def convnext_tiny_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_small_in22ft1k(pretrained=False, **kwargs): def convnext_small_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args)
return model return model
@ -456,14 +456,14 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs):
@register_model @register_model
def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_small_384_in22ft1k(pretrained=False, **kwargs): def convnext_small_384_in22ft1k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args)
return model return model
@ -491,14 +491,14 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs):
@register_model @register_model
def convnext_tiny_in22k(pretrained=False, **kwargs): def convnext_tiny_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs)
model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args)
return model return model
@register_model @register_model
def convnext_small_in22k(pretrained=False, **kwargs): def convnext_small_in22k(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs)
model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args)
return model return model

@ -477,7 +477,7 @@ def build_model_with_cfg(
pretrained_cfg: Optional[Dict] = None, pretrained_cfg: Optional[Dict] = None,
model_cfg: Optional[Any] = None, model_cfg: Optional[Any] = None,
feature_cfg: Optional[Dict] = None, feature_cfg: Optional[Dict] = None,
pretrained_strict: bool = False, pretrained_strict: bool = True,
pretrained_filter_fn: Optional[Callable] = None, pretrained_filter_fn: Optional[Callable] = None,
pretrained_custom_load: bool = False, pretrained_custom_load: bool = False,
kwargs_filter: Optional[Tuple[str]] = None, kwargs_filter: Optional[Tuple[str]] = None,

@ -91,7 +91,8 @@ class CondConv2d(nn.Module):
bias = torch.matmul(routing_weights, self.bias) bias = torch.matmul(routing_weights, self.bias)
bias = bias.view(B * self.out_channels) bias = bias.view(B * self.out_channels)
# move batch elements with channels so each batch element can be efficiently convolved with separate kernel # move batch elements with channels so each batch element can be efficiently convolved with separate kernel
x = x.view(1, B * C, H, W) # reshape instead of view to work with channels_last input
x = x.reshape(1, B * C, H, W)
if self.dynamic_padding: if self.dynamic_padding:
out = conv2d_same( out = conv2d_same(
x, weight, bias, stride=self.stride, padding=self.padding, x, weight, bias, stride=self.stride, padding=self.padding,

@ -10,16 +10,17 @@ from .helpers import to_2tuple
class Mlp(nn.Module): class Mlp(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks """ MLP as used in Vision Transformer, MLP-Mixer and related networks
""" """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop) drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer() self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0]) self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features, out_features) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1]) self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x): def forward(self, x):
@ -35,17 +36,18 @@ class GluMlp(nn.Module):
""" MLP w/ GLU style gating """ MLP w/ GLU style gating
See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
""" """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
assert hidden_features % 2 == 0 assert hidden_features % 2 == 0
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop) drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer() self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0]) self.drop1 = nn.Dropout(drop_probs[0])
self.fc2 = nn.Linear(hidden_features // 2, out_features) self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1]) self.drop2 = nn.Dropout(drop_probs[1])
def init_weights(self): def init_weights(self):
@ -67,14 +69,16 @@ class GluMlp(nn.Module):
class GatedMlp(nn.Module): class GatedMlp(nn.Module):
""" MLP as used in gMLP """ MLP as used in gMLP
""" """
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, def __init__(
gate_layer=None, drop=0.): self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
gate_layer=None, bias=True, drop=0.):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
bias = to_2tuple(bias)
drop_probs = to_2tuple(drop) drop_probs = to_2tuple(drop)
self.fc1 = nn.Linear(in_features, hidden_features) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
self.act = act_layer() self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0]) self.drop1 = nn.Dropout(drop_probs[0])
if gate_layer is not None: if gate_layer is not None:
@ -83,7 +87,7 @@ class GatedMlp(nn.Module):
hidden_features = hidden_features // 2 # FIXME base reduction on gate property? hidden_features = hidden_features // 2 # FIXME base reduction on gate property?
else: else:
self.gate = nn.Identity() self.gate = nn.Identity()
self.fc2 = nn.Linear(hidden_features, out_features) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1]) self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x): def forward(self, x):
@ -100,15 +104,18 @@ class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims """ MLP using 1x1 convs that keeps spatial dims
""" """
def __init__( def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
norm_layer=None, bias=True, drop=0.):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) bias = to_2tuple(bias)
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
self.act = act_layer() self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
self.drop = nn.Dropout(drop) self.drop = nn.Dropout(drop)
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)

@ -234,8 +234,8 @@ class PoolFormer(nn.Module):
return dict( return dict(
stem=r'^patch_embed', # stem and embed stem=r'^patch_embed', # stem and embed
blocks=[ blocks=[
(r'^network\.(\d+)\.(\d+)', None), (r'^network\.(\d+).*\.proj', (99999,)),
(r'^network\.(\d+)', (0,)), (r'^network\.(\d+)', None) if coarse else (r'^network\.(\d+)\.(\d+)', None),
(r'^norm', (99999,)) (r'^norm', (99999,))
], ],
) )

@ -458,7 +458,7 @@ class RegNet(nn.Module):
def group_matcher(self, coarse=False): def group_matcher(self, coarse=False):
return dict( return dict(
stem=r'^stem', stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)', blocks=r'^s(\d+)' if coarse else r'^s(\d+)\.b(\d+)',
) )
@torch.jit.ignore @torch.jit.ignore

@ -148,6 +148,49 @@ default_cfgs = {
'swsl_resnext101_32x16d': _cfg( 'swsl_resnext101_32x16d': _cfg(
url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'),
# Efficient Channel Attention ResNets
'ecaresnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
'ecaresnet50d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
crop_pct=1.0, test_input_size=(3, 352, 352)),
# Efficient Channel Attention ResNeXts
'ecaresnext26t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnext50t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
# Squeeze-Excitation ResNets, to eventually replace the models in senet.py # Squeeze-Excitation ResNets, to eventually replace the models in senet.py
'seresnet18': _cfg( 'seresnet18': _cfg(
url='', url='',
@ -180,7 +223,6 @@ default_cfgs = {
url='', url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
# Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py
'seresnext26d_32x4d': _cfg( 'seresnext26d_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth',
@ -199,55 +241,16 @@ default_cfgs = {
'seresnext101_32x8d': _cfg( 'seresnext101_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth',
interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0),
'seresnext101d_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth',
interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0),
'senet154': _cfg( 'senet154': _cfg(
url='', url='',
interpolation='bicubic', interpolation='bicubic',
first_conv='conv1.0'), first_conv='conv1.0'),
# Efficient Channel Attention ResNets # ResNets with anti-aliasing / blur pool
'ecaresnet26t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnetlight': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth',
interpolation='bicubic'),
'ecaresnet50d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet50t': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8),
crop_pct=0.95, test_input_size=(3, 320, 320)),
'ecaresnet101d': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnet101d_pruned': _cfg(
url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth',
interpolation='bicubic',
first_conv='conv1.0'),
'ecaresnet200d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)),
'ecaresnet269d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth',
interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10),
crop_pct=1.0, test_input_size=(3, 352, 352)),
# Efficient Channel Attention ResNeXts
'ecaresnext26t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
'ecaresnext50t_32x4d': _cfg(
url='',
interpolation='bicubic', first_conv='conv1.0'),
# ResNets with anti-aliasing blur pool
'resnetblur18': _cfg( 'resnetblur18': _cfg(
interpolation='bicubic'), interpolation='bicubic'),
'resnetblur50': _cfg( 'resnetblur50': _cfg(
@ -268,6 +271,9 @@ default_cfgs = {
'seresnetaa50d': _cfg( 'seresnetaa50d': _cfg(
url='', url='',
interpolation='bicubic', first_conv='conv1.0'), interpolation='bicubic', first_conv='conv1.0'),
'seresnextaa101d_32x8d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth',
interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0),
# ResNet-RS models # ResNet-RS models
'resnetrs50': _cfg( 'resnetrs50': _cfg(
@ -1157,98 +1163,6 @@ def ecaresnet50d(pretrained=False, **kwargs):
return _create_resnet('ecaresnet50d', pretrained, **model_args) return _create_resnet('ecaresnet50d', pretrained, **model_args)
@register_model
def resnetrs50(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@register_model
def resnetrs101(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@register_model
def resnetrs152(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@register_model
def resnetrs200(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@register_model
def resnetrs270(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@register_model
def resnetrs350(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@register_model
def resnetrs420(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)
@register_model @register_model
def ecaresnet50d_pruned(pretrained=False, **kwargs): def ecaresnet50d_pruned(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model pruned with eca. """Constructs a ResNet-50-D model pruned with eca.
@ -1346,72 +1260,6 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs):
return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args)
@register_model
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
@register_model
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)
@register_model
def resnetblur50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
@register_model
def resnetblur101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
@register_model
def resnetaa101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
@register_model @register_model
def seresnet18(pretrained=False, **kwargs): def seresnet18(pretrained=False, **kwargs):
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs)
@ -1535,9 +1383,187 @@ def seresnext101_32x8d(pretrained=False, **kwargs):
return _create_resnet('seresnext101_32x8d', pretrained, **model_args) return _create_resnet('seresnext101_32x8d', pretrained, **model_args)
@register_model
def seresnext101d_32x8d(pretrained=False, **kwargs):
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnext101d_32x8d', pretrained, **model_args)
@register_model @register_model
def senet154(pretrained=False, **kwargs): def senet154(pretrained=False, **kwargs):
model_args = dict( model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep',
down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('senet154', pretrained, **model_args) return _create_resnet('senet154', pretrained, **model_args)
@register_model
def resnetblur18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model with blur anti-aliasing
"""
model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur18', pretrained, **model_args)
@register_model
def resnetblur50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model with blur anti-aliasing
"""
model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs)
return _create_resnet('resnetblur50', pretrained, **model_args)
@register_model
def resnetblur50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur50d', pretrained, **model_args)
@register_model
def resnetblur101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with blur anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetblur101d', pretrained, **model_args)
@register_model
def resnetaa50d(pretrained=False, **kwargs):
"""Constructs a ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa50d', pretrained, **model_args)
@register_model
def resnetaa101d(pretrained=False, **kwargs):
"""Constructs a ResNet-101-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, **kwargs)
return _create_resnet('resnetaa101d', pretrained, **model_args)
@register_model
def seresnetaa50d(pretrained=False, **kwargs):
"""Constructs a SE=ResNet-50-D model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d,
stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnetaa50d', pretrained, **model_args)
@register_model
def seresnextaa101d_32x8d(pretrained=False, **kwargs):
"""Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing
"""
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8,
stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d,
block_args=dict(attn_layer='se'), **kwargs)
return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args)
@register_model
def resnetrs50(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-50 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs50', pretrained, **model_args)
@register_model
def resnetrs101(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-101 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs101', pretrained, **model_args)
@register_model
def resnetrs152(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-152 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs152', pretrained, **model_args)
@register_model
def resnetrs200(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-200 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs200', pretrained, **model_args)
@register_model
def resnetrs270(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-270 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs270', pretrained, **model_args)
@register_model
def resnetrs350(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-350 model.
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs350', pretrained, **model_args)
@register_model
def resnetrs420(pretrained=False, **kwargs):
"""Constructs a ResNet-RS-420 model
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
return _create_resnet('resnetrs420', pretrained, **model_args)

@ -0,0 +1,417 @@
""" Sequencer
Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf
"""
# Copyright (c) 2022. Yuki Tatsunami
# Licensed under the Apache License, Version 2.0 (the "License");
import math
from functools import partial
from typing import Tuple
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT
from .helpers import build_model_with_cfg, named_apply
from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = dict(
sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"),
sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"),
sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"),
)
def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False):
if isinstance(module, nn.Linear):
if name.startswith('head'):
nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias)
else:
if flax:
# Flax defaults
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
else:
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
if 'mlp' in name:
nn.init.normal_(module.bias, std=1e-6)
else:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Conv2d):
lecun_normal_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)):
stdv = 1.0 / math.sqrt(module.hidden_size)
for weight in module.parameters():
nn.init.uniform_(weight, -stdv, stdv)
elif hasattr(module, 'init_weights'):
module.init_weights()
def get_stage(
index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer,
norm_layer, act_layer, num_layers, bidirectional, union,
with_fc, drop=0., drop_path_rate=0., **kwargs):
assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios)
blocks = []
for block_idx in range(layers[index]):
drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(block_layer(
embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index],
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc,
drop=drop, drop_path=drop_path))
if index < len(embed_dims) - 1:
blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1]))
blocks = nn.Sequential(*blocks)
return blocks
class RNNIdentity(nn.Module):
def __init__(self, *args, **kwargs):
super(RNNIdentity, self).__init__()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]:
return x, None
class RNN2DBase(nn.Module):
def __init__(
self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
union="cat", with_fc=True):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = 2 * hidden_size if bidirectional else hidden_size
self.union = union
self.with_vertical = True
self.with_horizontal = True
self.with_fc = with_fc
self.fc = None
if with_fc:
if union == "cat":
self.fc = nn.Linear(2 * self.output_size, input_size)
elif union == "add":
self.fc = nn.Linear(self.output_size, input_size)
elif union == "vertical":
self.fc = nn.Linear(self.output_size, input_size)
self.with_horizontal = False
elif union == "horizontal":
self.fc = nn.Linear(self.output_size, input_size)
self.with_vertical = False
else:
raise ValueError("Unrecognized union: " + union)
elif union == "cat":
pass
if 2 * self.output_size != input_size:
raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.")
elif union == "add":
pass
if self.output_size != input_size:
raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
elif union == "vertical":
if self.output_size != input_size:
raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
self.with_horizontal = False
elif union == "horizontal":
if self.output_size != input_size:
raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.")
self.with_vertical = False
else:
raise ValueError("Unrecognized union: " + union)
self.rnn_v = RNNIdentity()
self.rnn_h = RNNIdentity()
def forward(self, x):
B, H, W, C = x.shape
if self.with_vertical:
v = x.permute(0, 2, 1, 3)
v = v.reshape(-1, H, C)
v, _ = self.rnn_v(v)
v = v.reshape(B, W, H, -1)
v = v.permute(0, 2, 1, 3)
else:
v = None
if self.with_horizontal:
h = x.reshape(-1, W, C)
h, _ = self.rnn_h(h)
h = h.reshape(B, H, W, -1)
else:
h = None
if v is not None and h is not None:
if self.union == "cat":
x = torch.cat([v, h], dim=-1)
else:
x = v + h
elif v is not None:
x = v
elif h is not None:
x = h
if self.fc is not None:
x = self.fc(x)
return x
class LSTM2D(RNN2DBase):
def __init__(
self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True, bidirectional: bool = True,
union="cat", with_fc=True):
super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc)
if self.with_vertical:
self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
if self.with_horizontal:
self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional)
class Sequencer2DBlock(nn.Module):
def __init__(
self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU,
num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.):
super().__init__()
channels_dim = int(mlp_ratio * dim)
self.norm1 = norm_layer(dim)
self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional,
union=union, with_fc=with_fc)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.rnn_tokens(self.norm1(x)))
x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
return x
class PatchEmbed(TimmPatchEmbed):
def forward(self, x):
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
else:
x = x.permute(0, 2, 3, 1) # BCHW -> BHWC
x = self.norm(x)
return x
class Shuffle(nn.Module):
def __init__(self):
super().__init__()
def forward(self, x):
if self.training:
B, H, W, C = x.shape
r = torch.randperm(H * W)
x = x.reshape(B, -1, C)
x = x[:, r, :].reshape(B, H, W, -1)
return x
class Downsample2D(nn.Module):
def __init__(self, input_dim, output_dim, patch_size):
super().__init__()
self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = x.permute(0, 3, 1, 2)
x = self.down(x)
x = x.permute(0, 2, 3, 1)
return x
class Sequencer2D(nn.Module):
def __init__(
self,
num_classes=1000,
img_size=224,
in_chans=3,
global_pool='avg',
layers=[4, 3, 8, 3],
patch_sizes=[7, 2, 1, 1],
embed_dims=[192, 384, 384, 384],
hidden_sizes=[48, 96, 96, 96],
mlp_ratios=[3.0, 3.0, 3.0, 3.0],
block_layer=Sequencer2DBlock,
rnn_layer=LSTM2D,
mlp_layer=Mlp,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
num_rnn_layers=1,
bidirectional=True,
union="cat",
with_fc=True,
drop_rate=0.,
drop_path_rate=0.,
nlhb=False,
stem_norm=False,
):
super().__init__()
assert global_pool in ('', 'avg')
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = embed_dims[-1] # num_features for consistency with other models
self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC)
self.embed_dims = embed_dims
self.stem = PatchEmbed(
img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None,
flatten=False)
self.blocks = nn.Sequential(*[
get_stage(
i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer,
rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer,
num_layers=num_rnn_layers, bidirectional=bidirectional,
union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate,
)
for i, _ in enumerate(embed_dims)])
self.norm = norm_layer(embed_dims[-1])
self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity()
self.init_weights(nlhb=nlhb)
def init_weights(self, nlhb=False):
head_bias = -math.log(self.num_classes) if nlhb else 0.
named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=[
(r'^blocks\.(\d+)\..*\.down', (99999,)),
(r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None),
(r'^norm', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=(1, 2))
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_sequencer2d(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Sequencer2D models.')
model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs)
return model
# main
@register_model
def sequencer2d_s(pretrained=False, **kwargs):
model_args = dict(
layers=[4, 3, 8, 3],
patch_sizes=[7, 2, 1, 1],
embed_dims=[192, 384, 384, 384],
hidden_sizes=[48, 96, 96, 96],
mlp_ratios=[3.0, 3.0, 3.0, 3.0],
rnn_layer=LSTM2D,
bidirectional=True,
union="cat",
with_fc=True,
**kwargs)
model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args)
return model
@register_model
def sequencer2d_m(pretrained=False, **kwargs):
model_args = dict(
layers=[4, 3, 14, 3],
patch_sizes=[7, 2, 1, 1],
embed_dims=[192, 384, 384, 384],
hidden_sizes=[48, 96, 96, 96],
mlp_ratios=[3.0, 3.0, 3.0, 3.0],
rnn_layer=LSTM2D,
bidirectional=True,
union="cat",
with_fc=True,
**kwargs)
model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args)
return model
@register_model
def sequencer2d_l(pretrained=False, **kwargs):
model_args = dict(
layers=[8, 8, 16, 4],
patch_sizes=[7, 2, 1, 1],
embed_dims=[192, 384, 384, 384],
hidden_sizes=[48, 96, 96, 96],
mlp_ratios=[3.0, 3.0, 3.0, 3.0],
rnn_layer=LSTM2D,
bidirectional=True,
union="cat",
with_fc=True,
**kwargs)
model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args)
return model

@ -0,0 +1,753 @@
""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below
Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman
"""
# --------------------------------------------------------
# Swin Transformer V2
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import math
from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'swinv2_tiny_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_tiny_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_small_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window8_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window16_256': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_base_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_base_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
'swinv2_large_window12_192_22k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
num_classes=21841, input_size=(3, 192, 192)
),
'swinv2_large_window12to16_192to256_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
input_size=(3, 256, 256)
),
'swinv2_large_window12to24_192to384_22kft1k': _cfg(
url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
input_size=(3, 384, 384), crop_pct=1.0,
),
}
def window_partition(x, window_size: Tuple[int, int]):
"""
Args:
x: (B, H, W, C)
window_size (int): window size
Returns:
windows: (num_windows*B, window_size, window_size, C)
"""
B, H, W, C = x.shape
x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C)
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C)
return windows
@register_notrace_function # reason: int argument is a Proxy
def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]):
"""
Args:
windows: (num_windows * B, window_size[0], window_size[1], C)
window_size (Tuple[int, int]): Window size
img_size (Tuple[int, int]): Image size
Returns:
x: (B, H, W, C)
"""
H, W = img_size
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1)
x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
return x
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
"""
def __init__(
self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.,
pretrained_window_size=[0, 0]):
super().__init__()
self.dim = dim
self.window_size = window_size # Wh, Ww
self.pretrained_window_size = pretrained_window_size
self.num_heads = num_heads
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
# mlp to generate continuous relative position bias
self.cpb_mlp = nn.Sequential(
nn.Linear(2, 512, bias=True),
nn.ReLU(inplace=True),
nn.Linear(512, num_heads, bias=False)
)
# get relative_coords_table
relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([
relative_coords_h,
relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2
if pretrained_window_size[0] > 0:
relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1)
else:
relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1)
relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
torch.abs(relative_coords_table) + 1.0) / math.log2(8)
self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += self.window_size[1] - 1
relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
self.qkv = nn.Linear(dim, dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(dim))
self.register_buffer('k_bias', torch.zeros(dim), persistent=False)
self.v_bias = nn.Parameter(torch.zeros(dim))
else:
self.q_bias = None
self.k_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask: Optional[torch.Tensor] = None):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)
# cosine attention
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads)
relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view(
self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
attn = attn + relative_position_bias.unsqueeze(0)
if mask is not None:
nW = mask.shape[0]
attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
attn = attn.view(-1, self.num_heads, N, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class SwinTransformerBlock(nn.Module):
r""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
pretrained_window_size (int): Window size in pretraining.
"""
def __init__(
self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = to_2tuple(input_resolution)
self.num_heads = num_heads
ws, ss = self._calc_window_shift(window_size, shift_size)
self.window_size: Tuple[int, int] = ws
self.shift_size: Tuple[int, int] = ss
self.window_area = self.window_size[0] * self.window_size[1]
self.mlp_ratio = mlp_ratio
self.attn = WindowAttention(
dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
pretrained_window_size=to_2tuple(pretrained_window_size))
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if any(self.shift_size):
# calculate attention mask for SW-MSA
H, W = self.input_resolution
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
cnt = 0
for h in (
slice(0, -self.window_size[0]),
slice(-self.window_size[0], -self.shift_size[0]),
slice(-self.shift_size[0], None)):
for w in (
slice(0, -self.window_size[1]),
slice(-self.window_size[1], -self.shift_size[1]),
slice(-self.shift_size[1], None)):
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
mask_windows = mask_windows.view(-1, self.window_area)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
else:
attn_mask = None
self.register_buffer("attn_mask", attn_mask)
def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]:
target_window_size = to_2tuple(target_window_size)
target_shift_size = to_2tuple(target_shift_size)
window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)]
shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)]
return tuple(window_size), tuple(shift_size)
def _attn(self, x):
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
x = x.view(B, H, W, C)
# cyclic shift
has_shift = any(self.shift_size)
if has_shift:
shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2))
else:
shifted_x = x
# partition windows
x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C
# reverse cyclic shift
if has_shift:
x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2))
else:
x = shifted_x
x = x.view(B, H * W, C)
return x
def forward(self, x):
x = x + self.drop_path1(self.norm1(self._attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(2 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
_assert(L == H * W, "input feature has wrong size")
_assert(H % 2 == 0, f"x size ({H}*{W}) are not even.")
_assert(W % 2 == 0, f"x size ({H}*{W}) are not even.")
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.reduction(x)
x = self.norm(x)
return x
class BasicLayer(nn.Module):
""" A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Local window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
pretrained_window_size (int): Local window size in pre-training.
"""
def __init__(
self, dim, input_resolution, depth, num_heads, window_size,
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
self.depth = depth
self.grad_checkpointing = False
# build blocks
self.blocks = nn.ModuleList([
SwinTransformerBlock(
dim=dim, input_resolution=input_resolution,
num_heads=num_heads, window_size=window_size,
shift_size=0 if (i % 2 == 0) else window_size // 2,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop, attn_drop=attn_drop,
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
norm_layer=norm_layer,
pretrained_window_size=pretrained_window_size)
for i in range(depth)])
# patch merging layer
if downsample is not None:
self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
else:
self.downsample = nn.Identity()
def forward(self, x):
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint.checkpoint(blk, x)
else:
x = blk(x)
x = self.downsample(x)
return x
def _init_respostnorm(self):
for blk in self.blocks:
nn.init.constant_(blk.norm1.bias, 0)
nn.init.constant_(blk.norm1.weight, 0)
nn.init.constant_(blk.norm2.bias, 0)
nn.init.constant_(blk.norm2.weight, 0)
class SwinTransformerV2(nn.Module):
r""" Swin Transformer V2
A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution`
- https://arxiv.org/abs/2111.09883
Args:
img_size (int | tuple(int)): Input image size. Default 224
patch_size (int | tuple(int)): Patch size. Default: 4
in_chans (int): Number of input image channels. Default: 3
num_classes (int): Number of classes for classification head. Default: 1000
embed_dim (int): Patch embedding dimension. Default: 96
depths (tuple(int)): Depth of each Swin Transformer layer.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size. Default: 7
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
drop_rate (float): Dropout rate. Default: 0
attn_drop_rate (float): Attention dropout rate. Default: 0
drop_path_rate (float): Stochastic depth rate. Default: 0.1
norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
patch_norm (bool): If True, add normalization after patch embedding. Default: True
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24),
window_size=7, mlp_ratio=4., qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
norm_layer=nn.LayerNorm, ape=False, patch_norm=True,
pretrained_window_sizes=(0, 0, 0, 0), **kwargs):
super().__init__()
self.num_classes = num_classes
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.num_layers = len(depths)
self.embed_dim = embed_dim
self.patch_norm = patch_norm
self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
# split image into non-overlapping patches
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
norm_layer=norm_layer if self.patch_norm else None)
num_patches = self.patch_embed.num_patches
# absolute position embedding
if ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.absolute_pos_embed, std=.02)
else:
self.absolute_pos_embed = None
self.pos_drop = nn.Dropout(p=drop_rate)
# stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
# build layers
self.layers = nn.ModuleList()
for i_layer in range(self.num_layers):
layer = BasicLayer(
dim=int(embed_dim * 2 ** i_layer),
input_resolution=(
self.patch_embed.grid_size[0] // (2 ** i_layer),
self.patch_embed.grid_size[1] // (2 ** i_layer)),
depth=depths[i_layer],
num_heads=num_heads[i_layer],
window_size=window_size,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate,
drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
norm_layer=norm_layer,
downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
pretrained_window_size=pretrained_window_sizes[i_layer]
)
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
for bly in self.layers:
bly._init_respostnorm()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
@torch.jit.ignore
def no_weight_decay(self):
nod = {'absolute_pos_embed'}
for n, m in self.named_modules():
if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]):
nod.add(n)
return nod
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^absolute_pos_embed|patch_embed', # stem and embed
blocks=r'^layers\.(\d+)' if coarse else [
(r'^layers\.(\d+).downsample', (0,)),
(r'^layers\.(\d+)\.\w+\.(\d+)', None),
(r'^norm', (99999,)),
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for l in self.layers:
l.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg')
self.global_pool = global_pool
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
if self.absolute_pos_embed is not None:
x = x + self.absolute_pos_embed
x = self.pos_drop(x)
for layer in self.layers:
x = layer(x)
x = self.norm(x) # B L C
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool == 'avg':
x = x.mean(dim=1)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if any([n in k for n in ('relative_position_index', 'relative_coords_table')]):
continue # skip buffers that should not be persistent
out_dict[k] = v
return out_dict
def _create_swin_transformer_v2(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
SwinTransformerV2, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def swinv2_tiny_window16_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2('swinv2_tiny_window16_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_tiny_window8_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2('swinv2_tiny_window8_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_small_window16_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2('swinv2_small_window16_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_small_window8_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs)
return _create_swin_transformer_v2('swinv2_small_window8_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window16_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2('swinv2_base_window16_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window8_256(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2('swinv2_base_window8_256', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12_192_22k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs)
return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12_192_22k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs):
"""
"""
model_kwargs = dict(
window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48),
pretrained_window_sizes=(12, 12, 12, 6), **kwargs)
return _create_swin_transformer_v2(
'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs)

@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint import torch.utils.checkpoint as checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -41,7 +42,7 @@ from .fx_features import register_notrace_function
from .helpers import build_model_with_cfg, named_apply from .helpers import build_model_with_cfg, named_apply
from .layers import DropPath, Mlp, to_2tuple, _assert from .layers import DropPath, Mlp, to_2tuple, _assert
from .registry import register_model from .registry import register_model
from .vision_transformer import checkpoint_filter_fn
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -51,7 +52,7 @@ def _cfg(url='', **kwargs):
'url': url, 'url': url,
'num_classes': 1000, 'num_classes': 1000,
'input_size': (3, 224, 224), 'input_size': (3, 224, 224),
'pool_size': None, 'pool_size': (7, 7),
'crop_pct': 0.9, 'crop_pct': 0.9,
'interpolation': 'bicubic', 'interpolation': 'bicubic',
'fixed_input_size': True, 'fixed_input_size': True,
@ -64,33 +65,38 @@ def _cfg(url='', **kwargs):
default_cfgs = { default_cfgs = {
'swin_v2_cr_tiny_384': _cfg( 'swinv2_cr_tiny_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swin_v2_cr_tiny_224': _cfg( 'swinv2_cr_tiny_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_tiny_ns_224': _cfg( 'swinv2_cr_tiny_ns_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth", url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth",
input_size=(3, 224, 224), crop_pct=0.9), input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_small_384': _cfg( 'swinv2_cr_small_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swin_v2_cr_small_224': _cfg( 'swinv2_cr_small_224': _cfg(
url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth",
input_size=(3, 224, 224), crop_pct=0.9), input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_base_384': _cfg( 'swinv2_cr_small_ns_224': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth",
'swin_v2_cr_base_224': _cfg( input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_base_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_base_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_large_384': _cfg( 'swinv2_cr_base_ns_224': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0),
'swin_v2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_huge_384': _cfg( 'swinv2_cr_large_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swin_v2_cr_huge_224': _cfg( 'swinv2_cr_large_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
'swin_v2_cr_giant_384': _cfg( 'swinv2_cr_huge_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0), url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swin_v2_cr_giant_224': _cfg( 'swinv2_cr_huge_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9),
'swinv2_cr_giant_384': _cfg(
url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)),
'swinv2_cr_giant_224': _cfg(
url="", input_size=(3, 224, 224), crop_pct=0.9), url="", input_size=(3, 224, 224), crop_pct=0.9),
} }
@ -179,14 +185,15 @@ class WindowMultiHeadAttention(nn.Module):
hidden_features=meta_hidden_dim, hidden_features=meta_hidden_dim,
out_features=num_heads, out_features=num_heads,
act_layer=nn.ReLU, act_layer=nn.ReLU,
drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without? drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without?
) )
self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads)))
self._make_pair_wise_relative_positions() self._make_pair_wise_relative_positions()
def _make_pair_wise_relative_positions(self) -> None: def _make_pair_wise_relative_positions(self) -> None:
"""Method initializes the pair-wise relative positions to compute the positional biases.""" """Method initializes the pair-wise relative positions to compute the positional biases."""
device = self.tau.device device = self.logit_scale.device
coordinates = torch.stack(torch.meshgrid([ coordinates = torch.stack(torch.meshgrid([
torch.arange(self.window_size[0], device=device), torch.arange(self.window_size[0], device=device),
torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1)
@ -245,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module):
query, key, value = qkv.unbind(0) query, key, value = qkv.unbind(0)
# compute attention map with scaled cosine attention # compute attention map with scaled cosine attention
denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1))
attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6) logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp()
attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1) attn = attn * logit_scale
attn = attn + self._relative_positional_encodings() attn = attn + self._relative_positional_encodings()
if mask is not None: if mask is not None:
# Apply mask if utilized # Apply mask if utilized
num_win: int = mask.shape[0] num_win: int = mask.shape[0]
@ -304,6 +312,7 @@ class SwinTransformerBlock(nn.Module):
window_size: Tuple[int, int], window_size: Tuple[int, int],
shift_size: Tuple[int, int] = (0, 0), shift_size: Tuple[int, int] = (0, 0),
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: Optional[float] = 0,
drop: float = 0.0, drop: float = 0.0,
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: float = 0.0, drop_path: float = 0.0,
@ -317,6 +326,7 @@ class SwinTransformerBlock(nn.Module):
self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size)
self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size))
self.window_area = self.window_size[0] * self.window_size[1] self.window_area = self.window_size[0] * self.window_size[1]
self.init_values: Optional[float] = init_values
# attn branch # attn branch
self.attn = WindowMultiHeadAttention( self.attn = WindowMultiHeadAttention(
@ -345,6 +355,7 @@ class SwinTransformerBlock(nn.Module):
self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self.norm3 = norm_layer(dim) if extra_norm else nn.Identity()
self._make_attention_mask() self._make_attention_mask()
self.init_weights()
def _calc_window_shift(self, target_window_size): def _calc_window_shift(self, target_window_size):
window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)]
@ -377,6 +388,12 @@ class SwinTransformerBlock(nn.Module):
attn_mask = None attn_mask = None
self.register_buffer("attn_mask", attn_mask, persistent=False) self.register_buffer("attn_mask", attn_mask, persistent=False)
def init_weights(self):
# extra, module specific weight init
if self.init_values is not None:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None:
"""Method updates the image resolution to be processed and window size and so the pair-wise relative positions. """Method updates the image resolution to be processed and window size and so the pair-wise relative positions.
@ -435,7 +452,7 @@ class SwinTransformerBlock(nn.Module):
Returns: Returns:
output (torch.Tensor): Output tensor of the shape [B, C, H, W] output (torch.Tensor): Output tensor of the shape [B, C, H, W]
""" """
# NOTE post-norm branches (op -> norm -> drop) # post-norm branches (op -> norm -> drop)
x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x))) x = x + self.drop_path2(self.norm2(self.mlp(x)))
x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant)
@ -522,6 +539,7 @@ class SwinTransformerStage(nn.Module):
feat_size: Tuple[int, int], feat_size: Tuple[int, int],
window_size: Tuple[int, int], window_size: Tuple[int, int],
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: Optional[float] = 0.0,
drop: float = 0.0, drop: float = 0.0,
drop_attn: float = 0.0, drop_attn: float = 0.0,
drop_path: Union[List[float], float] = 0.0, drop_path: Union[List[float], float] = 0.0,
@ -552,6 +570,7 @@ class SwinTransformerStage(nn.Module):
window_size=window_size, window_size=window_size,
shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]),
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop, drop=drop,
drop_attn=drop_attn, drop_attn=drop_attn,
drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path,
@ -634,6 +653,7 @@ class SwinTransformerV2Cr(nn.Module):
depths: Tuple[int, ...] = (2, 2, 6, 2), depths: Tuple[int, ...] = (2, 2, 6, 2),
num_heads: Tuple[int, ...] = (3, 6, 12, 24), num_heads: Tuple[int, ...] = (3, 6, 12, 24),
mlp_ratio: float = 4.0, mlp_ratio: float = 4.0,
init_values: Optional[float] = 0.,
drop_rate: float = 0.0, drop_rate: float = 0.0,
attn_drop_rate: float = 0.0, attn_drop_rate: float = 0.0,
drop_path_rate: float = 0.0, drop_path_rate: float = 0.0,
@ -674,6 +694,7 @@ class SwinTransformerV2Cr(nn.Module):
num_heads=num_heads, num_heads=num_heads,
window_size=window_size, window_size=window_size,
mlp_ratio=mlp_ratio, mlp_ratio=mlp_ratio,
init_values=init_values,
drop=drop_rate, drop=drop_rate,
drop_attn=attn_drop_rate, drop_attn=attn_drop_rate,
drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])],
@ -786,6 +807,23 @@ def init_weights(module: nn.Module, name: str = ''):
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
out_dict = {}
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
for k, v in state_dict.items():
if 'tau' in k:
# convert old tau based checkpoints -> logit_scale (inverse)
v = torch.log(1 / v)
k = k.replace('tau', 'logit_scale')
out_dict[k] = v
return out_dict
def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
@ -796,7 +834,7 @@ def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs):
@register_model @register_model
def swin_v2_cr_tiny_384(pretrained=False, **kwargs): def swinv2_cr_tiny_384(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 384x384, trained ImageNet-1k""" """Swin-T V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=96, embed_dim=96,
@ -804,11 +842,11 @@ def swin_v2_cr_tiny_384(pretrained=False, **kwargs):
num_heads=(3, 6, 12, 24), num_heads=(3, 6, 12, 24),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_tiny_384', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_tiny_224(pretrained=False, **kwargs): def swinv2_cr_tiny_224(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k""" """Swin-T V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=96, embed_dim=96,
@ -816,11 +854,11 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs):
num_heads=(3, 6, 12, 24), num_heads=(3, 6, 12, 24),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_tiny_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs): def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs):
"""Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms. """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms.
** Experimental, may make default if results are improved. ** ** Experimental, may make default if results are improved. **
""" """
@ -831,11 +869,11 @@ def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs):
extra_norm_stage=True, extra_norm_stage=True,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_small_384(pretrained=False, **kwargs): def swinv2_cr_small_384(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 384x384, trained ImageNet-1k""" """Swin-S V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=96, embed_dim=96,
@ -843,12 +881,12 @@ def swin_v2_cr_small_384(pretrained=False, **kwargs):
num_heads=(3, 6, 12, 24), num_heads=(3, 6, 12, 24),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs return _create_swin_transformer_v2_cr('swinv2_cr_small_384', pretrained=pretrained, **model_kwargs
) )
@register_model @register_model
def swin_v2_cr_small_224(pretrained=False, **kwargs): def swinv2_cr_small_224(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k""" """Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=96, embed_dim=96,
@ -856,11 +894,24 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs):
num_heads=(3, 6, 12, 24), num_heads=(3, 6, 12, 24),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_small_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_base_384(pretrained=False, **kwargs): def swinv2_cr_small_ns_224(pretrained=False, **kwargs):
"""Swin-S V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=96,
depths=(2, 2, 18, 2),
num_heads=(3, 6, 12, 24),
extra_norm_stage=True,
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_cr_base_384(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 384x384, trained ImageNet-1k""" """Swin-B V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=128, embed_dim=128,
@ -868,23 +919,36 @@ def swin_v2_cr_base_384(pretrained=False, **kwargs):
num_heads=(4, 8, 16, 32), num_heads=(4, 8, 16, 32),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_base_384', pretrained=pretrained, **model_kwargs)
@register_model
def swinv2_cr_base_224(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict(
embed_dim=128,
depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32),
**kwargs
)
return _create_swin_transformer_v2_cr('swinv2_cr_base_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_base_224(pretrained=False, **kwargs): def swinv2_cr_base_ns_224(pretrained=False, **kwargs):
"""Swin-B V2 CR @ 224x224, trained ImageNet-1k""" """Swin-B V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=128, embed_dim=128,
depths=(2, 2, 18, 2), depths=(2, 2, 18, 2),
num_heads=(4, 8, 16, 32), num_heads=(4, 8, 16, 32),
extra_norm_stage=True,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_base_ns_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_large_384(pretrained=False, **kwargs): def swinv2_cr_large_384(pretrained=False, **kwargs):
"""Swin-L V2 CR @ 384x384, trained ImageNet-1k""" """Swin-L V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=192, embed_dim=192,
@ -892,12 +956,12 @@ def swin_v2_cr_large_384(pretrained=False, **kwargs):
num_heads=(6, 12, 24, 48), num_heads=(6, 12, 24, 48),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs
) )
@register_model @register_model
def swin_v2_cr_large_224(pretrained=False, **kwargs): def swinv2_cr_large_224(pretrained=False, **kwargs):
"""Swin-L V2 CR @ 224x224, trained ImageNet-1k""" """Swin-L V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=192, embed_dim=192,
@ -905,11 +969,11 @@ def swin_v2_cr_large_224(pretrained=False, **kwargs):
num_heads=(6, 12, 24, 48), num_heads=(6, 12, 24, 48),
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_large_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_huge_384(pretrained=False, **kwargs): def swinv2_cr_huge_384(pretrained=False, **kwargs):
"""Swin-H V2 CR @ 384x384, trained ImageNet-1k""" """Swin-H V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=352, embed_dim=352,
@ -918,11 +982,11 @@ def swin_v2_cr_huge_384(pretrained=False, **kwargs):
extra_norm_period=6, extra_norm_period=6,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_huge_384', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_huge_224(pretrained=False, **kwargs): def swinv2_cr_huge_224(pretrained=False, **kwargs):
"""Swin-H V2 CR @ 224x224, trained ImageNet-1k""" """Swin-H V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=352, embed_dim=352,
@ -931,11 +995,11 @@ def swin_v2_cr_huge_224(pretrained=False, **kwargs):
extra_norm_period=6, extra_norm_period=6,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_huge_224', pretrained=pretrained, **model_kwargs)
@register_model @register_model
def swin_v2_cr_giant_384(pretrained=False, **kwargs): def swinv2_cr_giant_384(pretrained=False, **kwargs):
"""Swin-G V2 CR @ 384x384, trained ImageNet-1k""" """Swin-G V2 CR @ 384x384, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=512, embed_dim=512,
@ -944,12 +1008,12 @@ def swin_v2_cr_giant_384(pretrained=False, **kwargs):
extra_norm_period=6, extra_norm_period=6,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs return _create_swin_transformer_v2_cr('swinv2_cr_giant_384', pretrained=pretrained, **model_kwargs
) )
@register_model @register_model
def swin_v2_cr_giant_224(pretrained=False, **kwargs): def swinv2_cr_giant_224(pretrained=False, **kwargs):
"""Swin-G V2 CR @ 224x224, trained ImageNet-1k""" """Swin-G V2 CR @ 224x224, trained ImageNet-1k"""
model_kwargs = dict( model_kwargs = dict(
embed_dim=512, embed_dim=512,
@ -958,4 +1022,4 @@ def swin_v2_cr_giant_224(pretrained=False, **kwargs):
extra_norm_period=6, extra_norm_period=6,
**kwargs **kwargs
) )
return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs) return _create_swin_transformer_v2_cr('swinv2_cr_giant_224', pretrained=pretrained, **model_kwargs)

@ -23,6 +23,7 @@ import math
import logging import logging
from functools import partial from functools import partial
from collections import OrderedDict from collections import OrderedDict
from typing import Optional
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -107,7 +108,6 @@ default_cfgs = {
'vit_giant_patch14_224': _cfg(url=''), 'vit_giant_patch14_224': _cfg(url=''),
'vit_gigantic_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''),
'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
# patch models, imagenet21k (weights from official Google JAX impl) # patch models, imagenet21k (weights from official Google JAX impl)
'vit_tiny_patch16_224_in21k': _cfg( 'vit_tiny_patch16_224_in21k': _cfg(
@ -171,7 +171,12 @@ default_cfgs = {
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
), ),
# experimental 'vit_base_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'),
# experimental (may be removed)
'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95),
'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95),
'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_36x1_224': _cfg(url=''),
'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''),
'vit_base_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''),
@ -229,8 +234,7 @@ class Block(nn.Module):
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim) self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio) self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -240,6 +244,36 @@ class Block(nn.Module):
return x return x
class ResPostBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.init_values = init_values
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.init_weights()
def init_weights(self):
# NOTE this init overrides that base model init with specific changes for the block type
if self.init_values is not None:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def forward(self, x):
x = x + self.drop_path1(self.norm1(self.attn(x)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x
class ParallelBlock(nn.Module): class ParallelBlock(nn.Module):
def __init__( def __init__(
@ -290,8 +324,8 @@ class VisionTransformer(nn.Module):
def __init__( def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
""" """
Args: Args:
@ -305,33 +339,36 @@ class VisionTransformer(nn.Module):
num_heads (int): number of attention heads num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True qkv_bias (bool): enable bias for qkv if True
representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set init_values: (float): layer-scale init values
class_token (bool): use class token
fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None)
drop_rate (float): dropout rate drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate drop_path_rate (float): stochastic depth rate
weight_init: (str): weight init scheme weight_init (str): weight init scheme
init_values: (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer act_layer: (nn.Module): MLP activation layer
""" """
super().__init__() super().__init__()
assert global_pool in ('', 'avg', 'token') assert global_pool in ('', 'avg', 'token')
assert class_token or global_pool != 'token'
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU act_layer = act_layer or nn.GELU
self.num_classes = num_classes self.num_classes = num_classes
self.global_pool = global_pool self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 self.num_tokens = 1 if class_token else 0
self.grad_checkpointing = False self.grad_checkpointing = False
self.patch_embed = embed_layer( self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
self.pos_drop = nn.Dropout(p=drop_rate) self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -340,38 +377,21 @@ class VisionTransformer(nn.Module):
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)]) for i in range(depth)])
use_fc_norm = self.global_pool == 'avg'
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
# Representation layer. Used for original ViT models w/ in21k pretraining.
self.representation_size = representation_size
self.pre_logits = nn.Identity()
if representation_size:
self._reset_representation(representation_size)
# Classifier Head # Classifier Head
self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
final_chs = self.representation_size if self.representation_size else self.embed_dim self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip': if weight_init != 'skip':
self.init_weights(weight_init) self.init_weights(weight_init)
def _reset_representation(self, representation_size):
self.representation_size = representation_size
if self.representation_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', nn.Linear(self.embed_dim, self.representation_size)),
('act', nn.Tanh())
]))
else:
self.pre_logits = nn.Identity()
def init_weights(self, mode=''): def init_weights(self, mode=''):
assert mode in ('jax', 'jax_nlhb', 'moco', '') assert mode in ('jax', 'jax_nlhb', 'moco', '')
head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
trunc_normal_(self.pos_embed, std=.02) trunc_normal_(self.pos_embed, std=.02)
nn.init.normal_(self.cls_token, std=1e-6) if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(get_init_weights_vit(mode, head_bias), self) named_apply(get_init_weights_vit(mode, head_bias), self)
def _init_weights(self, m): def _init_weights(self, m):
@ -401,19 +421,17 @@ class VisionTransformer(nn.Module):
def get_classifier(self): def get_classifier(self):
return self.head return self.head
def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes self.num_classes = num_classes
if global_pool is not None: if global_pool is not None:
assert global_pool in ('', 'avg', 'token') assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool self.global_pool = global_pool
if representation_size is not None: self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self._reset_representation(representation_size)
final_chs = self.representation_size if self.representation_size else self.embed_dim
self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.patch_embed(x) x = self.patch_embed(x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = self.pos_drop(x + self.pos_embed) x = self.pos_drop(x + self.pos_embed)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x) x = checkpoint_seq(self.blocks, x)
@ -424,9 +442,8 @@ class VisionTransformer(nn.Module):
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):
if self.global_pool: if self.global_pool:
x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x) x = self.fc_norm(x)
x = self.pre_logits(x)
return x if pre_logits else self.head(x) return x if pre_logits else self.head(x)
def forward(self, x): def forward(self, x):
@ -441,6 +458,8 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''):
trunc_normal_(module.weight, std=.02) trunc_normal_(module.weight, std=.02)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.):
@ -449,9 +468,6 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0
if name.startswith('head'): if name.startswith('head'):
nn.init.zeros_(module.weight) nn.init.zeros_(module.weight)
nn.init.constant_(module.bias, head_bias) nn.init.constant_(module.bias, head_bias)
elif name.startswith('pre_logits'):
lecun_normal_(module.weight)
nn.init.zeros_(module.bias)
else: else:
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
@ -460,6 +476,8 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0
lecun_normal_(module.weight) lecun_normal_(module.weight)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def init_weights_vit_moco(module: nn.Module, name: str = ''): def init_weights_vit_moco(module: nn.Module, name: str = ''):
@ -473,6 +491,8 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''):
nn.init.xavier_uniform_(module.weight) nn.init.xavier_uniform_(module.weight)
if module.bias is not None: if module.bias is not None:
nn.init.zeros_(module.bias) nn.init.zeros_(module.bias)
elif hasattr(module, 'init_weights'):
module.init_weights()
def get_init_weights_vit(mode='jax', head_bias: float = 0.): def get_init_weights_vit(mode='jax', head_bias: float = 0.):
@ -543,9 +563,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights
model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
# model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
for i, block in enumerate(model.blocks.children()): for i, block in enumerate(model.blocks.children()):
block_prefix = f'{prefix}Transformer/encoderblock_{i}/' block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
@ -601,6 +622,9 @@ def checkpoint_filter_fn(state_dict, model):
# To resize pos embedding when using model at different size from pretrained weights # To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed( v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
elif 'pre_logits' in k:
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
continue
out_dict[k] = v out_dict[k] = v
return out_dict return out_dict
@ -609,21 +633,10 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None): if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.') raise RuntimeError('features_only not implemented for Vision Transformer models.')
# NOTE this extra code to support handling of repr size for in21k pretrained models
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
default_num_classes = pretrained_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
# but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface?
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
model = build_model_with_cfg( model = build_model_with_cfg(
VisionTransformer, variant, pretrained, VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg, pretrained_cfg=pretrained_cfg,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn, pretrained_filter_fn=checkpoint_filter_fn,
pretrained_custom_load='npz' in pretrained_cfg['url'], pretrained_custom_load='npz' in pretrained_cfg['url'],
**kwargs) **kwargs)
@ -696,16 +709,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs):
return model return model
@register_model
def vit_base2_patch32_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32)
# FIXME experiment
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs)
model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_base_patch32_384(pretrained=False, **kwargs): def vit_base_patch32_384(pretrained=False, **kwargs):
""" ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929).
@ -860,8 +863,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -872,8 +874,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -884,8 +885,7 @@ def vit_base_patch8_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -896,8 +896,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs)
patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs)
model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -908,8 +907,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs)
model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -920,8 +918,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights
""" """
model_kwargs = dict( model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs)
patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs)
model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs)
return model return model
@ -930,7 +927,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs):
def vit_base_patch16_224_sam(pretrained=False, **kwargs): def vit_base_patch16_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
""" """
# NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs)
return model return model
@ -940,7 +936,6 @@ def vit_base_patch16_224_sam(pretrained=False, **kwargs):
def vit_base_patch32_224_sam(pretrained=False, **kwargs): def vit_base_patch32_224_sam(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548
""" """
# NOTE original SAM weights release worked with representation_size=768
model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs)
return model return model
@ -1002,6 +997,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
return model return model
# Experimental models below
@register_model
def vit_base_patch32_plus_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+)
"""
model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+)
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs)
model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ residual post-norm
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False,
block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs)
model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
@register_model @register_model
def vit_small_patch16_36x1_224(pretrained=False, **kwargs): def vit_small_patch16_36x1_224(pretrained=False, **kwargs):
""" ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove.

@ -295,7 +295,7 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs):
ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer.
""" """
backbone = _resnetv2(layers=(3, 4, 9), **kwargs) backbone = _resnetv2(layers=(3, 4, 9), **kwargs)
model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs)
model = _create_vision_transformer_hybrid( model = _create_vision_transformer_hybrid(
'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs)
return model return model

@ -0,0 +1,558 @@
""" Relative Position Vision Transformer (ViT) in PyTorch
NOTE: these models are experimental / WIP, expect changes
Hacked together by / Copyright 2022, Ross Wightman
"""
import math
import logging
from functools import partial
from collections import OrderedDict
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model
_logger = logging.getLogger(__name__)
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'vit_relpos_base_patch32_plus_rpn_256': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth',
input_size=(3, 256, 256)),
'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)),
'vit_relpos_small_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'),
'vit_relpos_medium_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'),
'vit_relpos_base_patch16_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'),
'vit_relpos_base_patch16_cls_224': _cfg(
url=''),
'vit_relpos_base_patch16_gapcls_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'),
'vit_relpos_small_patch16_rpn_224': _cfg(url=''),
'vit_relpos_medium_patch16_rpn_224': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'),
'vit_relpos_base_patch16_rpn_224': _cfg(url=''),
}
def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor:
# cut and paste w/ modifications from swin / beit codebase
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
window_area = win_size[0] * win_size[1]
coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww
relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += win_size[1] - 1
relative_coords[:, :, 0] *= 2 * win_size[1] - 1
if class_token:
num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3
relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype)
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
relative_position_index[0, 0:] = num_relative_distance - 3
relative_position_index[0:, 0] = num_relative_distance - 2
relative_position_index[0, 0] = num_relative_distance - 1
else:
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
return relative_position_index
def gen_relative_log_coords(
win_size: Tuple[int, int],
pretrained_win_size: Tuple[int, int] = (0, 0),
mode='swin'
):
# as per official swin-v2 impl, supporting timm swin-v2-cr coords as well
relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32)
relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32)
relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2
if mode == 'swin':
if pretrained_win_size[0] > 0:
relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1)
relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1)
else:
relative_coords_table[:, :, 0] /= (win_size[0] - 1)
relative_coords_table[:, :, 1] /= (win_size[1] - 1)
relative_coords_table *= 8 # normalize to -8, 8
scale = math.log2(8)
else:
# FIXME we should support a form of normalization (to -1/1) for this mode?
scale = math.log2(math.e)
relative_coords_table = torch.sign(relative_coords_table) * torch.log2(
1.0 + relative_coords_table.abs()) / scale
return relative_coords_table
class RelPosMlp(nn.Module):
def __init__(
self,
window_size,
num_heads=8,
hidden_dim=128,
class_token=False,
mode='cr',
pretrained_window_size=(0, 0)
):
super().__init__()
self.window_size = window_size
self.window_area = self.window_size[0] * self.window_size[1]
self.class_token = 1 if class_token else 0
self.num_heads = num_heads
self.bias_shape = (self.window_area,) * 2 + (num_heads,)
self.apply_sigmoid = mode == 'swin'
mlp_bias = (True, False) if mode == 'swin' else True
self.mlp = Mlp(
2, # x, y
hidden_features=hidden_dim,
out_features=num_heads,
act_layer=nn.ReLU,
bias=mlp_bias,
drop=(0.125, 0.)
)
self.register_buffer(
"relative_position_index",
gen_relative_position_index(window_size),
persistent=False)
# get relative_coords_table
self.register_buffer(
"rel_coords_log",
gen_relative_log_coords(window_size, pretrained_window_size, mode=mode),
persistent=False)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.mlp(self.rel_coords_log)
if self.relative_position_index is not None:
relative_position_bias = relative_position_bias.view(-1, self.num_heads)[
self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.view(self.bias_shape)
relative_position_bias = relative_position_bias.permute(2, 0, 1)
if self.apply_sigmoid:
relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
if self.class_token:
relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0])
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosBias(nn.Module):
def __init__(self, window_size, num_heads, class_token=False):
super().__init__()
self.window_size = window_size
self.window_area = window_size[0] * window_size[1]
self.class_token = 1 if class_token else 0
self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,)
num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token
self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads))
self.register_buffer(
"relative_position_index",
gen_relative_position_index(self.window_size, class_token=self.class_token),
persistent=False,
)
self.init_weights()
def init_weights(self):
trunc_normal_(self.relative_position_bias_table, std=.02)
def get_bias(self) -> torch.Tensor:
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
# win_h * win_w, win_h * win_w, num_heads
relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1)
return relative_position_bias.unsqueeze(0).contiguous()
def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None):
return attn + self.get_bias()
class RelPosAttention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.):
super().__init__()
assert dim % num_heads == 0, 'dim should be divisible by num_heads'
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
attn = (q @ k.transpose(-2, -1)) * self.scale
if self.rel_pos is not None:
attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos)
elif shared_rel_pos is not None:
attn = attn + shared_rel_pos
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class LayerScale(nn.Module):
def __init__(self, dim, init_values=1e-5, inplace=False):
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x):
return x.mul_(self.gamma) if self.inplace else x * self.gamma
class RelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos)))
x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x))))
return x
class ResPostRelPosBlock(nn.Module):
def __init__(
self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.init_values = init_values
self.attn = RelPosAttention(
dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop)
self.norm1 = norm_layer(dim)
self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(dim)
self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.init_weights()
def init_weights(self):
# NOTE this init overrides that base model init with specific changes for the block type
if self.init_values is not None:
nn.init.constant_(self.norm1.weight, self.init_values)
nn.init.constant_(self.norm2.weight, self.init_values)
def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None):
x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos)))
x = x + self.drop_path2(self.norm2(self.mlp(x)))
return x
class VisionTransformerRelPos(nn.Module):
""" Vision Transformer w/ Relative Position Bias
Differing from classic vit, this impl
* uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed
* defaults to no class token (can be enabled)
* defaults to global avg pool for head (can be changed)
* layer-scale (residual branch gain) enabled
"""
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg',
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6,
class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip',
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
num_classes (int): number of classes for classification head
global_pool (str): type of global pooling for final sequence (default: 'avg')
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
init_values: (float): layer-scale init values
class_token (bool): use class token (default: False)
fc_norm (bool): use pre classifier norm instead of pre-pool
rel_pos_ty pe (str): type of relative position
shared_rel_pos (bool): share relative pos across all blocks
drop_rate (float): dropout rate
attn_drop_rate (float): attention dropout rate
drop_path_rate (float): stochastic depth rate
weight_init (str): weight init scheme
embed_layer (nn.Module): patch embedding layer
norm_layer: (nn.Module): normalization layer
act_layer: (nn.Module): MLP activation layer
"""
super().__init__()
assert global_pool in ('', 'avg', 'token')
assert class_token or global_pool != 'token'
norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
act_layer = act_layer or nn.GELU
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1 if class_token else 0
self.grad_checkpointing = False
self.patch_embed = embed_layer(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
feat_size = self.patch_embed.grid_size
rel_pos_args = dict(window_size=feat_size, class_token=class_token)
if rel_pos_type.startswith('mlp'):
if rel_pos_dim:
rel_pos_args['hidden_dim'] = rel_pos_dim
if 'swin' in rel_pos_type:
rel_pos_args['mode'] = 'swin'
rel_pos_cls = partial(RelPosMlp, **rel_pos_args)
else:
rel_pos_cls = partial(RelPosBias, **rel_pos_args)
self.shared_rel_pos = None
if shared_rel_pos:
self.shared_rel_pos = rel_pos_cls(num_heads=num_heads)
# NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both...
rel_pos_cls = None
self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
block_fn(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls,
init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i],
norm_layer=norm_layer, act_layer=act_layer)
for i in range(depth)])
self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity()
# Classifier Head
self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity()
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if weight_init != 'skip':
self.init_weights(weight_init)
def init_weights(self, mode=''):
assert mode in ('jax', 'moco', '')
if self.cls_token is not None:
nn.init.normal_(self.cls_token, std=1e-6)
# FIXME weight init scheme using PyTorch defaults curently
#named_apply(get_init_weights_vit(mode, head_bias), self)
@torch.jit.ignore
def no_weight_decay(self):
return {'cls_token'}
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^cls_token|patch_embed', # stem and embed
blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.patch_embed(x)
if self.cls_token is not None:
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None
for blk in self.blocks:
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos)
else:
x = blk(x, shared_rel_pos=shared_rel_pos)
x = self.norm(x)
return x
def forward_head(self, x, pre_logits: bool = False):
if self.global_pool:
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
x = self.fc_norm(x)
return x if pre_logits else self.head(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs)
return model
@register_model
def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token
"""
model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_small_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_small_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_medium_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_medium_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False,
class_token=True, global_pool='token', **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_cls_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position, class token present
NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled
Leaving here for comparisons w/ a future re-train as it performs quite well.
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs)
model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model
@register_model
def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs):
""" ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token
"""
model_kwargs = dict(
patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs)
model = _create_vision_transformer_relpos(
'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs)
return model

@ -71,7 +71,7 @@ def create_scheduler(args, optimizer):
elif args.sched == 'multistep': elif args.sched == 'multistep':
lr_scheduler = MultiStepLRScheduler( lr_scheduler = MultiStepLRScheduler(
optimizer, optimizer,
decay_t=args.decay_epochs, decay_t=args.decay_milestones,
decay_rate=args.decay_rate, decay_rate=args.decay_rate,
warmup_lr_init=args.warmup_lr, warmup_lr_init=args.warmup_lr,
warmup_t=args.warmup_epochs, warmup_t=args.warmup_epochs,

@ -34,9 +34,9 @@ def set_jit_fuser(fuser):
torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_override_can_fuse_on_gpu(True)
torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_texpr_fuser_enabled(False)
elif fuser == "nvfuser" or fuser == "nvf": elif fuser == "nvfuser" or fuser == "nvf":
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1'
os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1'
os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_texpr_fuser_enabled(False)
torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True) torch._C._jit_set_profiling_mode(True)

@ -1 +1 @@
__version__ = '0.6.1' __version__ = '0.7.0.dev0'

@ -156,6 +156,8 @@ parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N',
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).')
parser.add_argument('--start-epoch', default=None, type=int, metavar='N', parser.add_argument('--start-epoch', default=None, type=int, metavar='N',
help='manual epoch number (useful on restarts)') help='manual epoch number (useful on restarts)')
parser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES",
help='list of decay epoch indices for multistep lr. must be increasing')
parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', parser.add_argument('--decay-epochs', type=float, default=100, metavar='N',
help='epoch interval to decay LR') help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N', parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',

Loading…
Cancel
Save