diff --git a/README.md b/README.md index e79845b3..4c39c692 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,32 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### May 13, 2022 +* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. +* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. +* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program) + * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool + * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake) +* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020) +* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials +* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg) + +### May 2, 2022 +* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) + * `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool + * `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool +* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`) +* `vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae). + +### April 22, 2022 +* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). +* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. + * `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288 + * `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288 + ### March 23, 2022 * Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) * `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. @@ -375,6 +401,7 @@ A full version of the list below with source links can be found in the [document * ReXNet - https://arxiv.org/abs/2007.00992 * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 +* Sequencer2D - https://arxiv.org/abs/2205.01972 * Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725 * Swin Transformer - https://arxiv.org/abs/2103.14030 * Swin Transformer V2 - https://arxiv.org/abs/2111.09883 @@ -462,7 +489,7 @@ My current [documentation](https://rwightman.github.io/pytorch-image-models/) fo [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. -[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. +[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. diff --git a/docs/archived_changes.md b/docs/archived_changes.md index f8d88fd7..36b7b9a1 100644 --- a/docs/archived_changes.md +++ b/docs/archived_changes.md @@ -1,5 +1,134 @@ # Archived Changes +### June 8, 2021 +* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1. +* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1. + * NFNet inspired block layout with quad layer stem and no maxpool + * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288 + +### May 25, 2021 +* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Cleanup input_size/img_size override handling and testing for all vision transformer models +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + +### May 14, 2021 +* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. + * 1k trained variants: `tf_efficientnetv2_s/m/l` + * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` + * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` + * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` + * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` + * Some blank `efficientnetv2_*` models in-place for future native PyTorch training + +### May 5, 2021 +* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) +* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) +* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) +* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) +* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) +* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) +* Update ByoaNet attention modles + * Improve SA module inits + * Hack together experimental stand-alone Swin based attn module and `swinnet` + * Consistent '26t' model defs for experiments. +* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. +* WandB logging support + +### April 13, 2021 +* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer + +### April 12, 2021 +* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256. +* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training. +* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs + * Lambda Networks - https://arxiv.org/abs/2102.08602 + * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 + * Halo Nets - https://arxiv.org/abs/2103.12731 +* Adabelief optimizer contributed by Juntang Zhuang + +### April 1, 2021 +* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference +* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) + * Merged distilled variant into main for torchscript compatibility + * Some `timm` cleanup/style tweaks and weights have hub download support +* Cleanup Vision Transformer (ViT) models + * Merge distilled (DeiT) model into main so that torchscript can work + * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) + * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids + * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants + * nn.Sequential for block stack (does not break downstream compat) +* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) +* Add RegNetY-160 weights from DeiT teacher model +* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 +* Some fixes/improvements for TFDS dataset wrapper + +### March 7, 2021 +* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc). +* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation. + +### Feb 18, 2021 +* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). + * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. + * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. + * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). + * Matching the original pre-processing as closely as possible I get these results: + * `dm_nfnet_f6` - 86.352 + * `dm_nfnet_f5` - 86.100 + * `dm_nfnet_f4` - 85.834 + * `dm_nfnet_f3` - 85.676 + * `dm_nfnet_f2` - 85.178 + * `dm_nfnet_f1` - 84.696 + * `dm_nfnet_f0` - 83.464 + +### Feb 16, 2021 +* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. + * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` + * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` + * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. + +### Feb 12, 2021 +* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs + +### Feb 10, 2021 +* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') + * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` + * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` + * classic VGG (from torchvision, impl in `vgg`) +* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models +* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not. +* Fix a few bugs introduced since last pypi release + +### Feb 8, 2021 +* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352. + * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256 + * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256 + * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320 +* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed). +* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test. + +### Jan 30, 2021 +* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692) + +### Jan 25, 2021 +* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer +* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer +* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support + * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning +* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit +* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes +* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script + * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2` +* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar + * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp` +* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling + +### Jan 3, 2021 +* Add SE-ResNet-152D weights + * 256x256 val, 0.94 crop top-1 - 83.75 + * 320x320 val, 1.0 crop - 84.36 +* Update results files + ### Dec 18, 2020 * Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256 * 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25) diff --git a/docs/changes.md b/docs/changes.md index 6ff50756..d2965e8f 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,130 +1,130 @@ # Recent Changes -### June 8, 2021 -* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1. -* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1. - * NFNet inspired block layout with quad layer stem and no maxpool - * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288 -### May 25, 2021 -* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models -* Cleanup input_size/img_size override handling and testing for all vision transformer models -* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. - -### May 14, 2021 -* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. - * 1k trained variants: `tf_efficientnetv2_s/m/l` - * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` - * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` - * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` - * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` - * Some blank `efficientnetv2_*` models in-place for future native PyTorch training - -### May 5, 2021 -* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) -* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) -* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) -* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) -* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) -* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) -* Update ByoaNet attention modles - * Improve SA module inits - * Hack together experimental stand-alone Swin based attn module and `swinnet` - * Consistent '26t' model defs for experiments. -* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. -* WandB logging support - -### April 13, 2021 -* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer - -### April 12, 2021 -* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256. -* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training. -* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs - * Lambda Networks - https://arxiv.org/abs/2102.08602 - * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 - * Halo Nets - https://arxiv.org/abs/2103.12731 -* Adabelief optimizer contributed by Juntang Zhuang - -### April 1, 2021 -* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference -* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) - * Merged distilled variant into main for torchscript compatibility - * Some `timm` cleanup/style tweaks and weights have hub download support -* Cleanup Vision Transformer (ViT) models - * Merge distilled (DeiT) model into main so that torchscript can work - * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) - * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids - * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants - * nn.Sequential for block stack (does not break downstream compat) -* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) -* Add RegNetY-160 weights from DeiT teacher model -* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 -* Some fixes/improvements for TFDS dataset wrapper - -### March 7, 2021 -* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc). -* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation. - -### Feb 18, 2021 -* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). - * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. - * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. - * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). - * Matching the original pre-processing as closely as possible I get these results: - * `dm_nfnet_f6` - 86.352 - * `dm_nfnet_f5` - 86.100 - * `dm_nfnet_f4` - 85.834 - * `dm_nfnet_f3` - 85.676 - * `dm_nfnet_f2` - 85.178 - * `dm_nfnet_f1` - 84.696 - * `dm_nfnet_f0` - 83.464 - -### Feb 16, 2021 -* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. - * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` - * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` - * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` - * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. - -### Feb 12, 2021 -* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs - -### Feb 10, 2021 -* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') - * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` - * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` - * classic VGG (from torchvision, impl in `vgg`) -* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models -* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not. -* Fix a few bugs introduced since last pypi release - -### Feb 8, 2021 -* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352. - * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256 - * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256 - * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320 -* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed). -* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test. - -### Jan 30, 2021 -* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692) - -### Jan 25, 2021 -* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer -* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer -* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support - * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning -* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit -* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes -* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script - * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2` -* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar - * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp` -* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling - -### Jan 3, 2021 -* Add SE-ResNet-152D weights - * 256x256 val, 0.94 crop top-1 - 83.75 - * 320x320 val, 1.0 crop - 84.36 -* Update results files +### March 23, 2022 +* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) +* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. + +### March 21, 2022 +* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required. +* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights) + * `regnety_040` - 82.3 @ 224, 82.96 @ 288 + * `regnety_064` - 83.0 @ 224, 83.65 @ 288 + * `regnety_080` - 83.17 @ 224, 83.86 @ 288 + * `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act) + * `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act) + * `regnetz_040` - 83.67 @ 256, 84.25 @ 320 + * `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head) + * `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm) + * `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS) + * `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS) + * `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS) + * `xception41p` - 82 @ 299 (timm pre-act) + * `xception65` - 83.17 @ 299 + * `xception65p` - 83.14 @ 299 (timm pre-act) + * `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288 + * `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288 + * `resnetrs200` - 83.85 @ 256, 84.44 @ 320 +* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon) +* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks. +* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2 +* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets +* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer +* VOLO models w/ weights adapted from https://github.com/sail-sg/volo +* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc +* Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception +* Grouped conv support added to EfficientNet family +* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler +* Gradient checkpointing support added to many models +* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head` +* All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head` + +### Feb 2, 2022 +* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) +* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so. + * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs! + * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable. + +### Jan 14, 2022 +* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... +* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features +* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way... + * `mnasnet_small` - 65.6 top-1 + * `mobilenetv2_050` - 65.9 + * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1 + * `semnasnet_075` - 73 + * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0 +* TinyNet models added by [rsomani95](https://github.com/rsomani95) +* LCNet added via MobileNetV3 architecture + +### Nov 22, 2021 +* A number of updated weights anew new model defs + * `eca_halonext26ts` - 79.5 @ 256 + * `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288 + * `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth)) + * `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288 + * `sebotnet33ts_256` (new) - 81.2 @ 224 + * `lamhalobotnet50ts_256` - 81.5 @ 256 + * `halonet50ts` - 81.7 @ 256 + * `halo2botnet50ts_256` - 82.0 @ 256 + * `resnet101` - 82.0 @ 224, 82.8 @ 288 + * `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288 + * `resnet152` - 82.8 @ 224, 83.5 @ 288 + * `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320 + * `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320 +* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris) +* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare) + * models updated for tracing compatibility (almost full support with some distlled transformer exceptions) + +### Oct 19, 2021 +* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights) +* BCE loss and Repeated Augmentation support for RSB paper +* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl): + * Halo (https://arxiv.org/abs/2103.12731) + * Bottleneck Transformer (https://arxiv.org/abs/2101.11605) + * LambdaNetworks (https://arxiv.org/abs/2102.08602) +* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added +* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare) + +### Aug 18, 2021 +* Optimizer bonanza! + * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)) + * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA) + * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all! + * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself). +* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights. +* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested. + +### July 12, 2021 +* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare) + +### July 5-9, 2021 +* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res) + * top-1 82.34 @ 288x288 and 82.54 @ 320x320 +* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models. +* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare). + * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426 + +### June 23, 2021 +* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6) + +### June 20, 2021 +* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) + * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg) + * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights + * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work. + * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1) + * `vit_deit_*` renamed to just `deit_*` + * Remove my old small model, replace with DeiT compatible small w/ AugReg weights +* Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params. +* Add weights from official ResMLP release (https://github.com/facebookresearch/deit) +* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384. +* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237) +* NFNets and ResNetV2-BiT models work w/ Pytorch XLA now + * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered) + * eps values adjusted, will be slight differences but should be quite close +* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models +* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool +* Please report any regressions, this PR touched quite a few models. diff --git a/docs/index.md b/docs/index.md index 95f7df64..e022f891 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`. -For a more comprehensive set of docs (currently under development), please visit [timmdocs](https://fastai.github.io/timmdocs/) by [Aman Arora](https://github.com/amaarora). +For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora). ## Install @@ -20,17 +20,17 @@ pip install git+https://github.com/rwightman/pytorch-image-models.git ``` !!! info "Conda Environment" - All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x., 3.8.x., 3.9 + All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10 Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. - PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code. + PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code. I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: ``` conda create -n torch-env conda activate torch-env - conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge + conda install pytorch torchvision cudatoolkit=11.3 -c pytorch conda install pyyaml ``` diff --git a/docs/scripts.md b/docs/scripts.md index f48eec0d..0fbf18f1 100644 --- a/docs/scripts.md +++ b/docs/scripts.md @@ -12,7 +12,7 @@ To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process pe `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` -NOTE: It is recommended to use PyTorch 1.7+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. +NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. ## Validation / Inference Scripts diff --git a/tests/test_models.py b/tests/test_models.py index 4e50de6e..7ea9af6e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -202,13 +202,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size): pytest.skip("Fixed input size model > limit.") input_tensor = torch.randn((batch_size, *input_size)) + feat_dim = getattr(model, 'feature_dim', None) outputs = model.forward_features(input_tensor) if isinstance(outputs, (tuple, list)): # cannot currently verify multi-tensor output. pass else: - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features @@ -216,14 +218,16 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # check classifier name matches default_cfg diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 45ead5dc..8cb6c70a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -39,8 +39,10 @@ from .resnetv2 import * from .rexnet import * from .selecsls import * from .senet import * +from .sequencer import * from .sknet import * from .swin_transformer import * +from .swin_transformer_v2 import * from .swin_transformer_v2_cr import * from .tnt import * from .tresnet import * @@ -49,6 +51,7 @@ from .vgg import * from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * +from .vision_transformer_relpos import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/beit.py b/timm/models/beit.py index 557ac9b0..a56653dd 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -46,27 +46,27 @@ def _cfg(url='', **kwargs): default_cfgs = { 'beit_base_patch16_224': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), 'beit_base_patch16_384': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_base_patch16_224_in22k': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), 'beit_large_patch16_224': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), 'beit_large_patch16_384': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_large_patch16_512': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', input_size=(3, 512, 512), crop_pct=1.0, ), 'beit_large_patch16_224_in22k': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), } diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 9fd4525a..1aacef2b 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -421,14 +421,14 @@ def convnext_large(pretrained=False, **kwargs): @register_model def convnext_tiny_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args) return model @@ -456,14 +456,14 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): @register_model def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) return model @@ -491,14 +491,14 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): @register_model def convnext_tiny_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) return model diff --git a/timm/models/helpers.py b/timm/models/helpers.py index c4f48d6a..1276b68e 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -477,7 +477,7 @@ def build_model_with_cfg( pretrained_cfg: Optional[Dict] = None, model_cfg: Optional[Any] = None, feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = False, + pretrained_strict: bool = True, pretrained_filter_fn: Optional[Callable] = None, pretrained_custom_load: bool = False, kwargs_filter: Optional[Tuple[str]] = None, diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 8b4bbca8..43654c59 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -91,7 +91,8 @@ class CondConv2d(nn.Module): bias = torch.matmul(routing_weights, self.bias) bias = bias.view(B * self.out_channels) # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) if self.dynamic_padding: out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index a85e28d0..91e80a84 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -10,16 +10,17 @@ from .helpers import to_2tuple class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -35,17 +36,18 @@ class GluMlp(nn.Module): """ MLP w/ GLU style gating See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features assert hidden_features % 2 == 0 + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): @@ -67,14 +69,16 @@ class GluMlp(nn.Module): class GatedMlp(nn.Module): """ MLP as used in gMLP """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, - gate_layer=None, drop=0.): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if gate_layer is not None: @@ -83,7 +87,7 @@ class GatedMlp(nn.Module): hidden_features = hidden_features // 2 # FIXME base reduction on gate property? else: self.gate = nn.Identity() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -100,15 +104,18 @@ class ConvMlp(nn.Module): """ MLP using 1x1 convs that keeps spatial dims """ def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, + norm_layer=None, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() self.act = act_layer() - self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) def forward(self, x): x = self.fc1(x) diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 3cf6b1a3..17d657b0 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -234,8 +234,8 @@ class PoolFormer(nn.Module): return dict( stem=r'^patch_embed', # stem and embed blocks=[ - (r'^network\.(\d+)\.(\d+)', None), - (r'^network\.(\d+)', (0,)), + (r'^network\.(\d+).*\.proj', (99999,)), + (r'^network\.(\d+)', None) if coarse else (r'^network\.(\d+)\.(\d+)', None), (r'^norm', (99999,)) ], ) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 3e22bf56..9d1f1f64 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -458,7 +458,7 @@ class RegNet(nn.Module): def group_matcher(self, coarse=False): return dict( stem=r'^stem', - blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)', + blocks=r'^s(\d+)' if coarse else r'^s(\d+)\.b(\d+)', ) @torch.jit.ignore diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7a5afb3b..a7f0c0f6 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -148,6 +148,49 @@ default_cfgs = { 'swsl_resnext101_32x16d': _cfg( url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), + # Efficient Channel Attention ResNets + 'ecaresnet26t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnetlight': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', + interpolation='bicubic'), + 'ecaresnet50d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnet101d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnet101d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'ecaresnet269d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), + crop_pct=1.0, test_input_size=(3, 352, 352)), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnext50t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py 'seresnet18': _cfg( url='', @@ -180,7 +223,6 @@ default_cfgs = { url='', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), - # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py 'seresnext26d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', @@ -199,55 +241,16 @@ default_cfgs = { 'seresnext101_32x8d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth', interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), + 'seresnext101d_32x8d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth', + interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0), + 'senet154': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), - # Efficient Channel Attention ResNets - 'ecaresnet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), - crop_pct=0.95, test_input_size=(3, 320, 320)), - 'ecaresnetlight': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', - interpolation='bicubic'), - 'ecaresnet50d': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet50d_pruned': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet50t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), - crop_pct=0.95, test_input_size=(3, 320, 320)), - 'ecaresnet101d': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', - interpolation='bicubic', first_conv='conv1.0'), - 'ecaresnet101d_pruned': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet200d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), - 'ecaresnet269d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), - crop_pct=1.0, test_input_size=(3, 352, 352)), - - # Efficient Channel Attention ResNeXts - 'ecaresnext26t_32x4d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0'), - 'ecaresnext50t_32x4d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0'), - - # ResNets with anti-aliasing blur pool + # ResNets with anti-aliasing / blur pool 'resnetblur18': _cfg( interpolation='bicubic'), 'resnetblur50': _cfg( @@ -268,6 +271,9 @@ default_cfgs = { 'seresnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), + 'seresnextaa101d_32x8d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth', + interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0), # ResNet-RS models 'resnetrs50': _cfg( @@ -1157,98 +1163,6 @@ def ecaresnet50d(pretrained=False, **kwargs): return _create_resnet('ecaresnet50d', pretrained, **model_args) -@register_model -def resnetrs50(pretrained=False, **kwargs): - """Constructs a ResNet-RS-50 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs50', pretrained, **model_args) - - -@register_model -def resnetrs101(pretrained=False, **kwargs): - """Constructs a ResNet-RS-101 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs101', pretrained, **model_args) - - -@register_model -def resnetrs152(pretrained=False, **kwargs): - """Constructs a ResNet-RS-152 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs152', pretrained, **model_args) - - -@register_model -def resnetrs200(pretrained=False, **kwargs): - """Constructs a ResNet-RS-200 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs200', pretrained, **model_args) - - -@register_model -def resnetrs270(pretrained=False, **kwargs): - """Constructs a ResNet-RS-270 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs270', pretrained, **model_args) - - - -@register_model -def resnetrs350(pretrained=False, **kwargs): - """Constructs a ResNet-RS-350 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs350', pretrained, **model_args) - - -@register_model -def resnetrs420(pretrained=False, **kwargs): - """Constructs a ResNet-RS-420 model - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs420', pretrained, **model_args) - - @register_model def ecaresnet50d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-50-D model pruned with eca. @@ -1346,72 +1260,6 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs): return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) -@register_model -def resnetblur18(pretrained=False, **kwargs): - """Constructs a ResNet-18 model with blur anti-aliasing - """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur18', pretrained, **model_args) - - -@register_model -def resnetblur50(pretrained=False, **kwargs): - """Constructs a ResNet-50 model with blur anti-aliasing - """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur50', pretrained, **model_args) - - -@register_model -def resnetblur50d(pretrained=False, **kwargs): - """Constructs a ResNet-50-D model with blur anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur50d', pretrained, **model_args) - - -@register_model -def resnetblur101d(pretrained=False, **kwargs): - """Constructs a ResNet-101-D model with blur anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur101d', pretrained, **model_args) - - -@register_model -def resnetaa50d(pretrained=False, **kwargs): - """Constructs a ResNet-50-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa50d', pretrained, **model_args) - - -@register_model -def resnetaa101d(pretrained=False, **kwargs): - """Constructs a ResNet-101-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa101d', pretrained, **model_args) - - -@register_model -def seresnetaa50d(pretrained=False, **kwargs): - """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnetaa50d', pretrained, **model_args) - - @register_model def seresnet18(pretrained=False, **kwargs): model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) @@ -1535,9 +1383,187 @@ def seresnext101_32x8d(pretrained=False, **kwargs): return _create_resnet('seresnext101_32x8d', pretrained, **model_args) +@register_model +def seresnext101d_32x8d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101d_32x8d', pretrained, **model_args) + + @register_model def senet154(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) return _create_resnet('senet154', pretrained, **model_args) + + +@register_model +def resnetblur18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) + + +@register_model +def resnetblur50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with blur anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) + + +@register_model +def resnetblur50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur50d', pretrained, **model_args) + + +@register_model +def resnetblur101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur101d', pretrained, **model_args) + + +@register_model +def resnetaa50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa50d', pretrained, **model_args) + + +@register_model +def resnetaa101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa101d', pretrained, **model_args) + + +@register_model +def seresnetaa50d(pretrained=False, **kwargs): + """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnetaa50d', pretrained, **model_args) + + +@register_model +def seresnextaa101d_32x8d(pretrained=False, **kwargs): + """Constructs a SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args) + + +@register_model +def resnetrs50(pretrained=False, **kwargs): + """Constructs a ResNet-RS-50 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + """Constructs a ResNet-RS-101 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + """Constructs a ResNet-RS-152 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + """Constructs a ResNet-RS-200 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + """Constructs a ResNet-RS-270 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + """Constructs a ResNet-RS-350 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + """Constructs a ResNet-RS-420 model + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py new file mode 100644 index 00000000..b1ae92a4 --- /dev/null +++ b/timm/models/sequencer.py @@ -0,0 +1,417 @@ +""" Sequencer + +Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf + +""" +# Copyright (c) 2022. Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + + +import math +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from .helpers import build_model_with_cfg, named_apply +from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"), + sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"), + sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"), +) + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)): + stdv = 1.0 / math.sqrt(module.hidden_size) + for weight in module.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_stage( + index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, + norm_layer, act_layer, num_layers, bidirectional, union, + with_fc, drop=0., drop_path_rate=0., **kwargs): + assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) + blocks = [] + for block_idx in range(layers[index]): + drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_layer( + embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc, + drop=drop, drop_path=drop_path)) + + if index < len(embed_dims) - 1: + blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) + + blocks = nn.Sequential(*blocks) + return blocks + + +class RNNIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super(RNNIdentity, self).__init__() + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: + return x, None + + +class RNN2DBase(nn.Module): + + def __init__( + self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = 2 * hidden_size if bidirectional else hidden_size + self.union = union + + self.with_vertical = True + self.with_horizontal = True + self.with_fc = with_fc + + self.fc = None + if with_fc: + if union == "cat": + self.fc = nn.Linear(2 * self.output_size, input_size) + elif union == "add": + self.fc = nn.Linear(self.output_size, input_size) + elif union == "vertical": + self.fc = nn.Linear(self.output_size, input_size) + self.with_horizontal = False + elif union == "horizontal": + self.fc = nn.Linear(self.output_size, input_size) + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + elif union == "cat": + pass + if 2 * self.output_size != input_size: + raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.") + elif union == "add": + pass + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + elif union == "vertical": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_horizontal = False + elif union == "horizontal": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + + self.rnn_v = RNNIdentity() + self.rnn_h = RNNIdentity() + + def forward(self, x): + B, H, W, C = x.shape + + if self.with_vertical: + v = x.permute(0, 2, 1, 3) + v = v.reshape(-1, H, C) + v, _ = self.rnn_v(v) + v = v.reshape(B, W, H, -1) + v = v.permute(0, 2, 1, 3) + else: + v = None + + if self.with_horizontal: + h = x.reshape(-1, W, C) + h, _ = self.rnn_h(h) + h = h.reshape(B, H, W, -1) + else: + h = None + + if v is not None and h is not None: + if self.union == "cat": + x = torch.cat([v, h], dim=-1) + else: + x = v + h + elif v is not None: + x = v + elif h is not None: + x = h + + if self.fc is not None: + x = self.fc(x) + + return x + + +class LSTM2D(RNN2DBase): + + def __init__( + self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + if self.with_vertical: + self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + if self.with_horizontal: + self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + + +class Sequencer2DBlock(nn.Module): + def __init__( + self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, + num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., drop_path=0.): + super().__init__() + channels_dim = int(mlp_ratio * dim) + self.norm1 = norm_layer(dim) + self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc) + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.rnn_tokens(self.norm1(x))) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class PatchEmbed(TimmPatchEmbed): + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + else: + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + x = self.norm(x) + return x + + +class Shuffle(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if self.training: + B, H, W, C = x.shape + r = torch.randperm(H * W) + x = x.reshape(B, -1, C) + x = x[:, r, :].reshape(B, H, W, -1) + return x + + +class Downsample2D(nn.Module): + def __init__(self, input_dim, output_dim, patch_size): + super().__init__() + self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.down(x) + x = x.permute(0, 2, 3, 1) + return x + + +class Sequencer2D(nn.Module): + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + global_pool='avg', + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + block_layer=Sequencer2DBlock, + rnn_layer=LSTM2D, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_rnn_layers=1, + bidirectional=True, + union="cat", + with_fc=True, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + ): + super().__init__() + assert global_pool in ('', 'avg') + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = embed_dims[-1] # num_features for consistency with other models + self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC) + self.embed_dims = embed_dims + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None, + flatten=False) + + self.blocks = nn.Sequential(*[ + get_stage( + i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer, + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_rnn_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate, + ) + for i, _ in enumerate(embed_dims)]) + + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. + named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=[ + (r'^blocks\.(\d+)\..*\.down', (99999,)), + (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None), + (r'^norm', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=(1, 2)) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_sequencer2d(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Sequencer2D models.') + + model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs) + return model + + +# main + +@register_model +def sequencer2d_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_m(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 14, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_l(pretrained=False, **kwargs): + model_args = dict( + layers=[8, 8, 16, 4], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args) + return model diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py new file mode 100644 index 00000000..0c9db3dd --- /dev/null +++ b/timm/models/swin_transformer_v2.py @@ -0,0 +1,753 @@ +""" Swin Transformer V2 +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import math +from typing import Tuple, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'swinv2_tiny_window8_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth', + input_size=(3, 256, 256) + ), + 'swinv2_tiny_window16_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth', + input_size=(3, 256, 256) + ), + 'swinv2_small_window8_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth', + input_size=(3, 256, 256) + ), + 'swinv2_small_window16_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth', + input_size=(3, 256, 256) + ), + 'swinv2_base_window8_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth', + input_size=(3, 256, 256) + ), + 'swinv2_base_window16_256': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth', + input_size=(3, 256, 256) + ), + + 'swinv2_base_window12_192_22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth', + num_classes=21841, input_size=(3, 192, 192) + ), + 'swinv2_base_window12to16_192to256_22kft1k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth', + input_size=(3, 256, 256) + ), + 'swinv2_base_window12to24_192to384_22kft1k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), + 'swinv2_large_window12_192_22k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', + num_classes=21841, input_size=(3, 192, 192) + ), + 'swinv2_large_window12to16_192to256_22kft1k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth', + input_size=(3, 256, 256) + ), + 'swinv2_large_window12to24_192to384_22kft1k': _cfg( + url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth', + input_size=(3, 384, 384), crop_pct=1.0, + ), +} + + +def window_partition(x, window_size: Tuple[int, int]): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): + """ + Args: + windows: (num_windows * B, window_size[0], window_size[1], C) + window_size (Tuple[int, int]): Window size + img_size (Tuple[int, int]): Image size + + Returns: + x: (B, H, W, C) + """ + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__( + self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([ + relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / math.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index, persistent=False) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.register_buffer('k_bias', torch.zeros(dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pretraining. + """ + + def __init__( + self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = to_2tuple(input_resolution) + self.num_heads = num_heads + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] + self.mlp_ratio = mlp_ratio + + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if any(self.shift_size): + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): + for w in ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) + + def _attn(self, x): + H, W = self.input_resolution + B, L, C = x.shape + _assert(L == H * W, "input feature has wrong size") + x = x.view(B, H, W, C) + + # cyclic shift + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C + + # reverse cyclic shift + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + return x + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self._attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + _assert(L == H * W, "input feature has wrong size") + _assert(H % 2 == 0, f"x size ({H}*{W}) are not even.") + _assert(W % 2 == 0, f"x size ({H}*{W}) are not even.") + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__( + self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.grad_checkpointing = False + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = nn.Identity() + + def forward(self, x): + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x = self.downsample(x) + return x + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class SwinTransformerV2(nn.Module): + r""" Swin Transformer V2 + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + pretrained_window_sizes=(0, 0, 0, 0), **kwargs): + super().__init__() + + self.num_classes = num_classes + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + + # absolute position embedding + if ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + self.patch_embed.grid_size[0] // (2 ** i_layer), + self.patch_embed.grid_size[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + pretrained_window_size=pretrained_window_sizes[i_layer] + ) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + nod = {'absolute_pos_embed'} + for n, m in self.named_modules(): + if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): + nod.add(n) + return nod + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^absolute_pos_embed|patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'relative_coords_table')]): + continue # skip buffers that should not be persistent + out_dict[k] = v + return out_dict + + +def _create_swin_transformer_v2(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + SwinTransformerV2, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def swinv2_tiny_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_tiny_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_tiny_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_tiny_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_small_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_small_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_small_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_small_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12_192_22k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12_192_22k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 72a0b594..e6590e8a 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -41,7 +42,7 @@ from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, named_apply from .layers import DropPath, Mlp, to_2tuple, _assert from .registry import register_model -from .vision_transformer import checkpoint_filter_fn + _logger = logging.getLogger(__name__) @@ -51,7 +52,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), - 'pool_size': None, + 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True, @@ -64,33 +65,38 @@ def _cfg(url='', **kwargs): default_cfgs = { - 'swin_v2_cr_tiny_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_tiny_224': _cfg( + 'swinv2_cr_tiny_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_tiny_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_tiny_ns_224': _cfg( + 'swinv2_cr_tiny_ns_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_small_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_small_224': _cfg( + 'swinv2_cr_small_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_small_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_base_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_base_224': _cfg( + 'swinv2_cr_small_ns_224': _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth", + input_size=(3, 224, 224), crop_pct=0.9), + 'swinv2_cr_base_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_base_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_large_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_large_224': _cfg( + 'swinv2_cr_base_ns_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_huge_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_huge_224': _cfg( + 'swinv2_cr_large_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_large_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_giant_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_giant_224': _cfg( + 'swinv2_cr_huge_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_huge_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=0.9), + 'swinv2_cr_giant_384': _cfg( + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), + 'swinv2_cr_giant_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), } @@ -179,14 +185,15 @@ class WindowMultiHeadAttention(nn.Module): hidden_features=meta_hidden_dim, out_features=num_heads, act_layer=nn.ReLU, - drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without? + drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without? ) - self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) + # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads))) self._make_pair_wise_relative_positions() def _make_pair_wise_relative_positions(self) -> None: """Method initializes the pair-wise relative positions to compute the positional biases.""" - device = self.tau.device + device = self.logit_scale.device coordinates = torch.stack(torch.meshgrid([ torch.arange(self.window_size[0], device=device), torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) @@ -245,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module): query, key, value = qkv.unbind(0) # compute attention map with scaled cosine attention - denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) - attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6) - attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1) + attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale attn = attn + self._relative_positional_encodings() + if mask is not None: # Apply mask if utilized num_win: int = mask.shape[0] @@ -304,6 +312,7 @@ class SwinTransformerBlock(nn.Module): window_size: Tuple[int, int], shift_size: Tuple[int, int] = (0, 0), mlp_ratio: float = 4.0, + init_values: Optional[float] = 0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: float = 0.0, @@ -317,6 +326,7 @@ class SwinTransformerBlock(nn.Module): self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] + self.init_values: Optional[float] = init_values # attn branch self.attn = WindowMultiHeadAttention( @@ -345,6 +355,7 @@ class SwinTransformerBlock(nn.Module): self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self._make_attention_mask() + self.init_weights() def _calc_window_shift(self, target_window_size): window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] @@ -377,6 +388,12 @@ class SwinTransformerBlock(nn.Module): attn_mask = None self.register_buffer("attn_mask", attn_mask, persistent=False) + def init_weights(self): + # extra, module specific weight init + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. @@ -435,7 +452,7 @@ class SwinTransformerBlock(nn.Module): Returns: output (torch.Tensor): Output tensor of the shape [B, C, H, W] """ - # NOTE post-norm branches (op -> norm -> drop) + # post-norm branches (op -> norm -> drop) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant) @@ -522,6 +539,7 @@ class SwinTransformerStage(nn.Module): feat_size: Tuple[int, int], window_size: Tuple[int, int], mlp_ratio: float = 4.0, + init_values: Optional[float] = 0.0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: Union[List[float], float] = 0.0, @@ -552,6 +570,7 @@ class SwinTransformerStage(nn.Module): window_size=window_size, shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop, drop_attn=drop_attn, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, @@ -634,6 +653,7 @@ class SwinTransformerV2Cr(nn.Module): depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] = (3, 6, 12, 24), mlp_ratio: float = 4.0, + init_values: Optional[float] = 0., drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -674,6 +694,7 @@ class SwinTransformerV2Cr(nn.Module): num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop_rate, drop_attn=attn_drop_rate, drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], @@ -786,6 +807,23 @@ def init_weights(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'tau' in k: + # convert old tau based checkpoints -> logit_scale (inverse) + v = torch.log(1 / v) + k = k.replace('tau', 'logit_scale') + out_dict[k] = v + return out_dict def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): @@ -796,7 +834,7 @@ def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): @register_model -def swin_v2_cr_tiny_384(pretrained=False, **kwargs): +def swinv2_cr_tiny_384(pretrained=False, **kwargs): """Swin-T V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -804,11 +842,11 @@ def swin_v2_cr_tiny_384(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_tiny_224(pretrained=False, **kwargs): +def swinv2_cr_tiny_224(pretrained=False, **kwargs): """Swin-T V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -816,11 +854,11 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs): +def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs): """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms. ** Experimental, may make default if results are improved. ** """ @@ -831,11 +869,11 @@ def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs): extra_norm_stage=True, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_small_384(pretrained=False, **kwargs): +def swinv2_cr_small_384(pretrained=False, **kwargs): """Swin-S V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -843,12 +881,12 @@ def swin_v2_cr_small_384(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_small_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_small_224(pretrained=False, **kwargs): +def swinv2_cr_small_224(pretrained=False, **kwargs): """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -856,11 +894,24 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_small_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_384(pretrained=False, **kwargs): +def swinv2_cr_small_ns_224(pretrained=False, **kwargs): + """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + extra_norm_stage=True, + **kwargs + ) + return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_cr_base_384(pretrained=False, **kwargs): """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=128, @@ -868,23 +919,36 @@ def swin_v2_cr_base_384(pretrained=False, **kwargs): num_heads=(4, 8, 16, 32), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_base_384', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_cr_base_224(pretrained=False, **kwargs): + """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + **kwargs + ) + return _create_swin_transformer_v2_cr('swinv2_cr_base_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_224(pretrained=False, **kwargs): +def swinv2_cr_base_ns_224(pretrained=False, **kwargs): """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), + extra_norm_stage=True, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_base_ns_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_large_384(pretrained=False, **kwargs): +def swinv2_cr_large_384(pretrained=False, **kwargs): """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=192, @@ -892,12 +956,12 @@ def swin_v2_cr_large_384(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_large_224(pretrained=False, **kwargs): +def swinv2_cr_large_224(pretrained=False, **kwargs): """Swin-L V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=192, @@ -905,11 +969,11 @@ def swin_v2_cr_large_224(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_large_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_384(pretrained=False, **kwargs): +def swinv2_cr_huge_384(pretrained=False, **kwargs): """Swin-H V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=352, @@ -918,11 +982,11 @@ def swin_v2_cr_huge_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_huge_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_224(pretrained=False, **kwargs): +def swinv2_cr_huge_224(pretrained=False, **kwargs): """Swin-H V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=352, @@ -931,11 +995,11 @@ def swin_v2_cr_huge_224(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_huge_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_giant_384(pretrained=False, **kwargs): +def swinv2_cr_giant_384(pretrained=False, **kwargs): """Swin-G V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=512, @@ -944,12 +1008,12 @@ def swin_v2_cr_giant_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_giant_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_giant_224(pretrained=False, **kwargs): +def swinv2_cr_giant_224(pretrained=False, **kwargs): """Swin-G V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=512, @@ -958,4 +1022,4 @@ def swin_v2_cr_giant_224(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_giant_224', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 17faba53..59fd7849 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -23,6 +23,7 @@ import math import logging from functools import partial from collections import OrderedDict +from typing import Optional import torch import torch.nn as nn @@ -107,7 +108,6 @@ default_cfgs = { 'vit_giant_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''), - 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( @@ -171,7 +171,12 @@ default_cfgs = { mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', ), - # experimental + 'vit_base_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), @@ -229,8 +234,7 @@ class Block(nn.Module): self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -240,6 +244,36 @@ class Block(nn.Module): return x +class ResPostBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + class ParallelBlock(nn.Module): def __init__( @@ -290,8 +324,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: @@ -305,33 +339,36 @@ class VisionTransformer(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate - weight_init: (str): weight init scheme - init_values: (float): layer-scale init values + weight_init (str): weight init scheme embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 + self.num_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -340,38 +377,21 @@ class VisionTransformer(nn.Module): dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer) for i in range(depth)]) - use_fc_norm = self.global_pool == 'avg' self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() - # Representation layer. Used for original ViT models w/ in21k pretraining. - self.representation_size = representation_size - self.pre_logits = nn.Identity() - if representation_size: - self._reset_representation(representation_size) - # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) - def _reset_representation(self, representation_size): - self.representation_size = representation_size - if self.representation_size: - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(self.embed_dim, self.representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) - nn.init.normal_(self.cls_token, std=1e-6) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): @@ -401,19 +421,17 @@ class VisionTransformer(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): + def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool - if representation_size is not None: - self._reset_representation(representation_size) - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -424,9 +442,8 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) - x = self.pre_logits(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -441,6 +458,8 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): @@ -449,9 +468,6 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) - elif name.startswith('pre_logits'): - lecun_normal_(module.weight) - nn.init.zeros_(module.bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: @@ -460,6 +476,8 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): @@ -473,6 +491,8 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): @@ -543,9 +563,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' @@ -601,6 +622,9 @@ def checkpoint_filter_fn(state_dict, model): # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue out_dict[k] = v return out_dict @@ -609,21 +633,10 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - # NOTE this extra code to support handling of repr size for in21k pretrained models pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) - default_num_classes = pretrained_cfg['num_classes'] - num_classes = kwargs.get('num_classes', default_num_classes) - repr_size = kwargs.pop('representation_size', None) - if repr_size is not None and num_classes != default_num_classes: - # Remove representation layer if fine-tuning. This may not always be the desired action, - # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? - _logger.warning("Removing representation layer for fine-tuning.") - repr_size = None - model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, - representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, pretrained_custom_load='npz' in pretrained_cfg['url'], **kwargs) @@ -696,16 +709,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model -@register_model -def vit_base2_patch32_256(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32) - # FIXME experiment - """ - model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) - model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). @@ -860,8 +863,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -872,8 +874,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -884,8 +885,7 @@ def vit_base_patch8_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -896,8 +896,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -908,8 +907,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -920,8 +918,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -930,7 +927,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): def vit_base_patch16_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -940,7 +936,6 @@ def vit_base_patch16_224_sam(pretrained=False, **kwargs): def vit_base_patch32_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -1002,6 +997,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): return model +# Experimental models below + +@register_model +def vit_base_patch32_plus_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16+) + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, + block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) + model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_patch16_36x1_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 0eee2044..24ff2096 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -295,7 +295,7 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer_hybrid( 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py new file mode 100644 index 00000000..0c2ac376 --- /dev/null +++ b/timm/models/vision_transformer_relpos.py @@ -0,0 +1,558 @@ +""" Relative Position Vision Transformer (ViT) in PyTorch + +NOTE: these models are experimental / WIP, expect changes + +Hacked together by / Copyright 2022, Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'vit_relpos_base_patch32_plus_rpn_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', + input_size=(3, 256, 256)), + 'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)), + + 'vit_relpos_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'), + 'vit_relpos_medium_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'), + 'vit_relpos_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), + + 'vit_relpos_base_patch16_cls_224': _cfg( + url=''), + 'vit_relpos_base_patch16_gapcls_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), + + 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), + 'vit_relpos_medium_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'), + 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), +} + + +def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: + # cut and paste w/ modifications from swin / beit codebase + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = win_size[0] * win_size[1] + coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + if class_token: + num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + else: + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +def gen_relative_log_coords( + win_size: Tuple[int, int], + pretrained_win_size: Tuple[int, int] = (0, 0), + mode='swin' +): + # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well + relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 + if mode == 'swin': + if pretrained_win_size[0] > 0: + relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1) + relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1) + else: + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + scale = math.log2(8) + else: + # FIXME we should support a form of normalization (to -1/1) for this mode? + scale = math.log2(math.e) + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / scale + return relative_coords_table + + +class RelPosMlp(nn.Module): + def __init__( + self, + window_size, + num_heads=8, + hidden_dim=128, + class_token=False, + mode='cr', + pretrained_window_size=(0, 0) + ): + super().__init__() + self.window_size = window_size + self.window_area = self.window_size[0] * self.window_size[1] + self.class_token = 1 if class_token else 0 + self.num_heads = num_heads + self.bias_shape = (self.window_area,) * 2 + (num_heads,) + self.apply_sigmoid = mode == 'swin' + + mlp_bias = (True, False) if mode == 'swin' else True + self.mlp = Mlp( + 2, # x, y + hidden_features=hidden_dim, + out_features=num_heads, + act_layer=nn.ReLU, + bias=mlp_bias, + drop=(0.125, 0.) + ) + + self.register_buffer( + "relative_position_index", + gen_relative_position_index(window_size), + persistent=False) + + # get relative_coords_table + self.register_buffer( + "rel_coords_log", + gen_relative_log_coords(window_size, pretrained_window_size, mode=mode), + persistent=False) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.mlp(self.rel_coords_log) + if self.relative_position_index is not None: + relative_position_bias = relative_position_bias.view(-1, self.num_heads)[ + self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.view(self.bias_shape) + relative_position_bias = relative_position_bias.permute(2, 0, 1) + if self.apply_sigmoid: + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + if self.class_token: + relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + return relative_position_bias.unsqueeze(0).contiguous() + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosBias(nn.Module): + + def __init__(self, window_size, num_heads, class_token=False): + super().__init__() + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.class_token = 1 if class_token else 0 + self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.register_buffer( + "relative_position_index", + gen_relative_position_index(self.window_size, class_token=self.class_token), + persistent=False, + ) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + # win_h * win_w, win_h * win_w, num_heads + relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1) + return relative_position_bias.unsqueeze(0).contiguous() + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.rel_pos = rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class RelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostRelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class VisionTransformerRelPos(nn.Module): + """ Vision Transformer w/ Relative Position Bias + + Differing from classic vit, this impl + * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed + * defaults to no class token (can be enabled) + * defaults to global avg pool for head (can be changed) + * layer-scale (residual branch gain) enabled + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, + class_token=False, fc_norm=False, rel_pos_type='mlp', shared_rel_pos=False, rel_pos_dim=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'avg') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + class_token (bool): use class token (default: False) + fc_norm (bool): use pre classifier norm instead of pre-pool + rel_pos_ty pe (str): type of relative position + shared_rel_pos (bool): share relative pos across all blocks + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 if class_token else 0 + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + feat_size = self.patch_embed.grid_size + + rel_pos_args = dict(window_size=feat_size, class_token=class_token) + if rel_pos_type.startswith('mlp'): + if rel_pos_dim: + rel_pos_args['hidden_dim'] = rel_pos_dim + if 'swin' in rel_pos_type: + rel_pos_args['mode'] = 'swin' + rel_pos_cls = partial(RelPosMlp, **rel_pos_args) + else: + rel_pos_cls = partial(RelPosBias, **rel_pos_args) + self.shared_rel_pos = None + if shared_rel_pos: + self.shared_rel_pos = rel_pos_cls(num_heads=num_heads) + # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... + rel_pos_cls = None + + self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, + init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'moco', '') + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + # FIXME weight init scheme using PyTorch defaults curently + #named_apply(get_init_weights_vit(mode, head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos) + else: + x = blk(x, shared_rel_pos=shared_rel_pos) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs) + return model + + +@register_model +def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_medium_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, + class_token=True, global_pool='token', **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_cls_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present + NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled + Leaving here for comparisons w/ a future re-train as it performs quite well. + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 72a979c2..3e100fe0 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -71,7 +71,7 @@ def create_scheduler(args, optimizer): elif args.sched == 'multistep': lr_scheduler = MultiStepLRScheduler( optimizer, - decay_t=args.decay_epochs, + decay_t=args.decay_milestones, decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, diff --git a/timm/utils/jit.py b/timm/utils/jit.py index 6039823f..a32cbd40 100644 --- a/timm/utils/jit.py +++ b/timm/utils/jit.py @@ -34,9 +34,9 @@ def set_jit_fuser(fuser): torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_texpr_fuser_enabled(False) elif fuser == "nvfuser" or fuser == "nvf": - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = '1' - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' - os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' + os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' + #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' + #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) diff --git a/timm/version.py b/timm/version.py index 8411e551..06f971e2 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.1' +__version__ = '0.7.0.dev0' diff --git a/train.py b/train.py index d1d218f7..cec7efeb 100755 --- a/train.py +++ b/train.py @@ -156,6 +156,8 @@ parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') +parser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", + help='list of decay epoch indices for multistep lr. must be increasing') parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',