From 8f0bc0591e8fe4803cae892f860d75401b517d6c Mon Sep 17 00:00:00 2001 From: SeeFun Date: Tue, 5 Apr 2022 20:00:57 +0800 Subject: [PATCH 01/32] fix convnext args --- timm/models/convnext.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 9fd4525a..1aacef2b 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -421,14 +421,14 @@ def convnext_large(pretrained=False, **kwargs): @register_model def convnext_tiny_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_in22ft1k', pretrained=pretrained, **model_args) return model @@ -456,14 +456,14 @@ def convnext_xlarge_in22ft1k(pretrained=False, **kwargs): @register_model def convnext_tiny_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_384_in22ft1k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_384_in22ft1k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_384_in22ft1k', pretrained=pretrained, **model_args) return model @@ -491,14 +491,14 @@ def convnext_xlarge_384_in22ft1k(pretrained=False, **kwargs): @register_model def convnext_tiny_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) model = _create_convnext('convnext_tiny_in22k', pretrained=pretrained, **model_args) return model @register_model def convnext_small_in22k(pretrained=False, **kwargs): - model_args = dict(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + model_args = dict(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) model = _create_convnext('convnext_small_in22k', pretrained=pretrained, **model_args) return model From fbf597049cee9f0b8e2aff2be2a37abea6d22460 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Apr 2022 16:52:05 -0700 Subject: [PATCH 02/32] Update README and change timmdocs link in documentation --- README.md | 8 +- docs/archived_changes.md | 129 ++++++++++++++++++++ docs/changes.md | 254 +++++++++++++++++++-------------------- docs/index.md | 8 +- docs/scripts.md | 2 +- 5 files changed, 268 insertions(+), 133 deletions(-) diff --git a/README.md b/README.md index e79845b3..355cedaf 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,12 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### April 22, 2022 +* `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). 
+* Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. + * `seresnext101d_32x8d` - 83.69 @ 224, 84.35 @ 288 + * `seresnextaa101d_32x8d` (anti-aliased w/ AvgPool2d) - 83.85 @ 224, 84.57 @ 288 + ### March 23, 2022 * Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) * `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. @@ -462,7 +468,7 @@ My current [documentation](https://rwightman.github.io/pytorch-image-models/) fo [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) by [Chris Hughes](https://github.com/Chris-hughes10) is an extensive blog post covering many aspects of `timm` in detail. -[timmdocs](https://fastai.github.io/timmdocs/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. +[timmdocs](http://timm.fast.ai/) is quickly becoming a much more comprehensive set of documentation for `timm`. A big thanks to [Aman Arora](https://github.com/amaarora) for his efforts creating timmdocs. [paperswithcode](https://paperswithcode.com/lib/timm) is a good resource for browsing the models within `timm`. diff --git a/docs/archived_changes.md b/docs/archived_changes.md index f8d88fd7..36b7b9a1 100644 --- a/docs/archived_changes.md +++ b/docs/archived_changes.md @@ -1,5 +1,134 @@ # Archived Changes +### June 8, 2021 +* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1. +* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1. + * NFNet inspired block layout with quad layer stem and no maxpool + * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288 + +### May 25, 2021 +* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Cleanup input_size/img_size override handling and testing for all vision transformer models +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + +### May 14, 2021 +* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. + * 1k trained variants: `tf_efficientnetv2_s/m/l` + * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` + * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` + * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` + * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` + * Some blank `efficientnetv2_*` models in-place for future native PyTorch training + +### May 5, 2021 +* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) +* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) +* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) +* Add CoaT models and weights. 
Thanks [Mohammed Rizin](https://github.com/morizin) +* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) +* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) +* Update ByoaNet attention modles + * Improve SA module inits + * Hack together experimental stand-alone Swin based attn module and `swinnet` + * Consistent '26t' model defs for experiments. +* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. +* WandB logging support + +### April 13, 2021 +* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer + +### April 12, 2021 +* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256. +* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training. +* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs + * Lambda Networks - https://arxiv.org/abs/2102.08602 + * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 + * Halo Nets - https://arxiv.org/abs/2103.12731 +* Adabelief optimizer contributed by Juntang Zhuang + +### April 1, 2021 +* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference +* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) + * Merged distilled variant into main for torchscript compatibility + * Some `timm` cleanup/style tweaks and weights have hub download support +* Cleanup Vision Transformer (ViT) models + * Merge distilled (DeiT) model into main so that torchscript can work + * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) + * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids + * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants + * nn.Sequential for block stack (does not break downstream compat) +* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) +* Add RegNetY-160 weights from DeiT teacher model +* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 +* Some fixes/improvements for TFDS dataset wrapper + +### March 7, 2021 +* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc). +* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation. + +### Feb 18, 2021 +* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). + * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. + * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. + * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). 
+ * Matching the original pre-processing as closely as possible I get these results: + * `dm_nfnet_f6` - 86.352 + * `dm_nfnet_f5` - 86.100 + * `dm_nfnet_f4` - 85.834 + * `dm_nfnet_f3` - 85.676 + * `dm_nfnet_f2` - 85.178 + * `dm_nfnet_f1` - 84.696 + * `dm_nfnet_f0` - 83.464 + +### Feb 16, 2021 +* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. + * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` + * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` + * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` + * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. + +### Feb 12, 2021 +* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs + +### Feb 10, 2021 +* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') + * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` + * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` + * classic VGG (from torchvision, impl in `vgg`) +* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models +* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not. +* Fix a few bugs introduced since last pypi release + +### Feb 8, 2021 +* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352. + * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256 + * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256 + * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320 +* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed). +* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test. 
+ +### Jan 30, 2021 +* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692) + +### Jan 25, 2021 +* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer +* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer +* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support + * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning +* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit +* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes +* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script + * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2` +* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar + * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp` +* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling + +### Jan 3, 2021 +* Add SE-ResNet-152D weights + * 256x256 val, 0.94 crop top-1 - 83.75 + * 320x320 val, 1.0 crop - 84.36 +* Update results files + ### Dec 18, 2020 * Add ResNet-101D, ResNet-152D, and ResNet-200D weights trained @ 256x256 * 256x256 val, 0.94 crop (top-1) - 101D (82.33), 152D (83.08), 200D (83.25) diff --git a/docs/changes.md b/docs/changes.md index 6ff50756..d2965e8f 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,130 +1,130 @@ # Recent Changes -### June 8, 2021 -* Add first ResMLP weights, trained in PyTorch XLA on TPU-VM w/ my XLA branch. 24 block variant, 79.2 top-1. -* Add ResNet51-Q model w/ pretrained weights at 82.36 top-1. - * NFNet inspired block layout with quad layer stem and no maxpool - * Same param count (35.7M) and throughput as ResNetRS-50 but +1.5 top-1 @ 224x224 and +2.5 top-1 at 288x288 -### May 25, 2021 -* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models -* Cleanup input_size/img_size override handling and testing for all vision transformer models -* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. - -### May 14, 2021 -* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. 
- * 1k trained variants: `tf_efficientnetv2_s/m/l` - * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` - * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` - * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` - * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` - * Some blank `efficientnetv2_*` models in-place for future native PyTorch training - -### May 5, 2021 -* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) -* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) -* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) -* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) -* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) -* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) -* Update ByoaNet attention modles - * Improve SA module inits - * Hack together experimental stand-alone Swin based attn module and `swinnet` - * Consistent '26t' model defs for experiments. -* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. -* WandB logging support - -### April 13, 2021 -* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer - -### April 12, 2021 -* Add ECA-NFNet-L1 (slimmed down F1 w/ SiLU, 41M params) trained with this code. 84% top-1 @ 320x320. Trained at 256x256. -* Add EfficientNet-V2S model (unverified model definition) weights. 83.3 top-1 @ 288x288. Only trained single res 224. Working on progressive training. -* Add ByoaNet model definition (Bring-your-own-attention) w/ SelfAttention block and corresponding SA/SA-like modules and model defs - * Lambda Networks - https://arxiv.org/abs/2102.08602 - * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 - * Halo Nets - https://arxiv.org/abs/2103.12731 -* Adabelief optimizer contributed by Juntang Zhuang - -### April 1, 2021 -* Add snazzy `benchmark.py` script for bulk `timm` model benchmarking of train and/or inference -* Add Pooling-based Vision Transformer (PiT) models (from https://github.com/naver-ai/pit) - * Merged distilled variant into main for torchscript compatibility - * Some `timm` cleanup/style tweaks and weights have hub download support -* Cleanup Vision Transformer (ViT) models - * Merge distilled (DeiT) model into main so that torchscript can work - * Support updated weight init (defaults to old still) that closer matches original JAX impl (possibly better training from scratch) - * Separate hybrid model defs into different file and add several new model defs to fiddle with, support patch_size != 1 for hybrids - * Fix fine-tuning num_class changes (PiT and ViT) and pos_embed resizing (Vit) with distilled variants - * nn.Sequential for block stack (does not break downstream compat) -* TnT (Transformer-in-Transformer) models contributed by author (from https://gitee.com/mindspore/mindspore/tree/master/model_zoo/research/cv/TNT) -* Add RegNetY-160 weights from DeiT teacher model -* Add new NFNet-L0 w/ SE attn (rename `nfnet_l0b`->`nfnet_l0`) weights 82.75 top-1 @ 288x288 -* Some fixes/improvements for TFDS dataset wrapper - -### March 7, 2021 -* First 0.4.x PyPi release w/ NFNets (& related), ByoB (GPU-Efficient, RepVGG, etc). 
-* Change feature extraction for pre-activation nets (NFNets, ResNetV2) to return features before activation. - -### Feb 18, 2021 -* Add pretrained weights and model variants for NFNet-F* models from [DeepMind Haiku impl](https://github.com/deepmind/deepmind-research/tree/master/nfnets). - * Models are prefixed with `dm_`. They require SAME padding conv, skipinit enabled, and activation gains applied in act fn. - * These models are big, expect to run out of GPU memory. With the GELU activiation + other options, they are roughly 1/2 the inference speed of my SiLU PyTorch optimized `s` variants. - * Original model results are based on pre-processing that is not the same as all other models so you'll see different results in the results csv (once updated). - * Matching the original pre-processing as closely as possible I get these results: - * `dm_nfnet_f6` - 86.352 - * `dm_nfnet_f5` - 86.100 - * `dm_nfnet_f4` - 85.834 - * `dm_nfnet_f3` - 85.676 - * `dm_nfnet_f2` - 85.178 - * `dm_nfnet_f1` - 84.696 - * `dm_nfnet_f0` - 83.464 - -### Feb 16, 2021 -* Add Adaptive Gradient Clipping (AGC) as per https://arxiv.org/abs/2102.06171. Integrated w/ PyTorch gradient clipping via mode arg that defaults to prev 'norm' mode. For backward arg compat, clip-grad arg must be specified to enable when using train.py. - * AGC w/ default clipping factor `--clip-grad .01 --clip-mode agc` - * PyTorch global norm of 1.0 (old behaviour, always norm), `--clip-grad 1.0` - * PyTorch value clipping of 10, `--clip-grad 10. --clip-mode value` - * AGC performance is definitely sensitive to the clipping factor. More experimentation needed to determine good values for smaller batch sizes and optimizers besides those in paper. So far I've found .001-.005 is necessary for stable RMSProp training w/ NFNet/NF-ResNet. - -### Feb 12, 2021 -* Update Normalization-Free nets to include new NFNet-F (https://arxiv.org/abs/2102.06171) model defs - -### Feb 10, 2021 -* More model archs, incl a flexible ByobNet backbone ('Bring-your-own-blocks') - * GPU-Efficient-Networks (https://github.com/idstcv/GPU-Efficient-Networks), impl in `byobnet.py` - * RepVGG (https://github.com/DingXiaoH/RepVGG), impl in `byobnet.py` - * classic VGG (from torchvision, impl in `vgg`) -* Refinements to normalizer layer arg handling and normalizer+act layer handling in some models -* Default AMP mode changed to native PyTorch AMP instead of APEX. Issues not being fixed with APEX. Native works with `--channels-last` and `--torchscript` model training, APEX does not. -* Fix a few bugs introduced since last pypi release - -### Feb 8, 2021 -* Add several ResNet weights with ECA attention. 26t & 50t trained @ 256, test @ 320. 269d train @ 256, fine-tune @320, test @ 352. - * `ecaresnet26t` - 79.88 top-1 @ 320x320, 79.08 @ 256x256 - * `ecaresnet50t` - 82.35 top-1 @ 320x320, 81.52 @ 256x256 - * `ecaresnet269d` - 84.93 top-1 @ 352x352, 84.87 @ 320x320 -* Remove separate tiered (`t`) vs tiered_narrow (`tn`) ResNet model defs, all `tn` changed to `t` and `t` models removed (`seresnext26t_32x4d` only model w/ weights that was removed). -* Support model default_cfgs with separate train vs test resolution `test_input_size` and remove extra `_320` suffix ResNet model defs that were just for test. 
- -### Jan 30, 2021 -* Add initial "Normalization Free" NF-RegNet-B* and NF-ResNet model definitions based on [paper](https://arxiv.org/abs/2101.08692) - -### Jan 25, 2021 -* Add ResNetV2 Big Transfer (BiT) models w/ ImageNet-1k and 21k weights from https://github.com/google-research/big_transfer -* Add official R50+ViT-B/16 hybrid models + weights from https://github.com/google-research/vision_transformer -* ImageNet-21k ViT weights are added w/ model defs and representation layer (pre logits) support - * NOTE: ImageNet-21k classifier heads were zero'd in original weights, they are only useful for transfer learning -* Add model defs and weights for DeiT Vision Transformer models from https://github.com/facebookresearch/deit -* Refactor dataset classes into ImageDataset/IterableImageDataset + dataset specific parser classes -* Add Tensorflow-Datasets (TFDS) wrapper to allow use of TFDS image classification sets with train script - * Ex: `train.py /data/tfds --dataset tfds/oxford_iiit_pet --val-split test --model resnet50 -b 256 --amp --num-classes 37 --opt adamw --lr 3e-4 --weight-decay .001 --pretrained -j 2` -* Add improved .tar dataset parser that reads images from .tar, folder of .tar files, or .tar within .tar - * Run validation on full ImageNet-21k directly from tar w/ BiT model: `validate.py /data/fall11_whole.tar --model resnetv2_50x1_bitm_in21k --amp` -* Models in this update should be stable w/ possible exception of ViT/BiT, possibility of some regressions with train/val scripts and dataset handling - -### Jan 3, 2021 -* Add SE-ResNet-152D weights - * 256x256 val, 0.94 crop top-1 - 83.75 - * 320x320 val, 1.0 crop - 84.36 -* Update results files +### March 23, 2022 +* Add `ParallelBlock` and `LayerScale` option to base vit models to support model configs in [Three things everyone should know about ViT](https://arxiv.org/abs/2203.09795) +* `convnext_tiny_hnf` (head norm first) weights trained with (close to) A2 recipe, 82.2% top-1, could do better with more epochs. + +### March 21, 2022 +* Merge `norm_norm_norm`. **IMPORTANT** this update for a coming 0.6.x release will likely de-stabilize the master branch for a while. Branch [`0.5.x`](https://github.com/rwightman/pytorch-image-models/tree/0.5.x) or a previous 0.5.x release can be used if stability is required. 
+* Significant weights update (all TPU trained) as described in this [release](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights) + * `regnety_040` - 82.3 @ 224, 82.96 @ 288 + * `regnety_064` - 83.0 @ 224, 83.65 @ 288 + * `regnety_080` - 83.17 @ 224, 83.86 @ 288 + * `regnetv_040` - 82.44 @ 224, 83.18 @ 288 (timm pre-act) + * `regnetv_064` - 83.1 @ 224, 83.71 @ 288 (timm pre-act) + * `regnetz_040` - 83.67 @ 256, 84.25 @ 320 + * `regnetz_040h` - 83.77 @ 256, 84.5 @ 320 (w/ extra fc in head) + * `resnetv2_50d_gn` - 80.8 @ 224, 81.96 @ 288 (pre-act GroupNorm) + * `resnetv2_50d_evos` 80.77 @ 224, 82.04 @ 288 (pre-act EvoNormS) + * `regnetz_c16_evos` - 81.9 @ 256, 82.64 @ 320 (EvoNormS) + * `regnetz_d8_evos` - 83.42 @ 256, 84.04 @ 320 (EvoNormS) + * `xception41p` - 82 @ 299 (timm pre-act) + * `xception65` - 83.17 @ 299 + * `xception65p` - 83.14 @ 299 (timm pre-act) + * `resnext101_64x4d` - 82.46 @ 224, 83.16 @ 288 + * `seresnext101_32x8d` - 83.57 @ 224, 84.270 @ 288 + * `resnetrs200` - 83.85 @ 256, 84.44 @ 320 +* HuggingFace hub support fixed w/ initial groundwork for allowing alternative 'config sources' for pretrained model definitions and weights (generic local file / remote url support soon) +* SwinTransformer-V2 implementation added. Submitted by [Christoph Reich](https://github.com/ChristophReich1996). Training experiments and model changes by myself are ongoing so expect compat breaks. +* Swin-S3 (AutoFormerV2) models / weights added from https://github.com/microsoft/Cream/tree/main/AutoFormerV2 +* MobileViT models w/ weights adapted from https://github.com/apple/ml-cvnets +* PoolFormer models w/ weights adapted from https://github.com/sail-sg/poolformer +* VOLO models w/ weights adapted from https://github.com/sail-sg/volo +* Significant work experimenting with non-BatchNorm norm layers such as EvoNorm, FilterResponseNorm, GroupNorm, etc +* Enhance support for alternate norm + act ('NormAct') layers added to a number of models, esp EfficientNet/MobileNetV3, RegNet, and aligned Xception +* Grouped conv support added to EfficientNet family +* Add 'group matching' API to all models to allow grouping model parameters for application of 'layer-wise' LR decay, lr scale added to LR scheduler +* Gradient checkpointing support added to many models +* `forward_head(x, pre_logits=False)` fn added to all models to allow separate calls of `forward_features` + `forward_head` +* All vision transformer and vision MLP models update to return non-pooled / non-token selected features from `foward_features`, for consistency with CNN models, token selection or pooling now applied in `forward_head` + +### Feb 2, 2022 +* [Chris Hughes](https://github.com/Chris-hughes10) posted an exhaustive run through of `timm` on his blog yesterday. Well worth a read. [Getting Started with PyTorch Image Models (timm): A Practitioner’s Guide](https://towardsdatascience.com/getting-started-with-pytorch-image-models-timm-a-practitioners-guide-4e77b4bf9055) +* I'm currently prepping to merge the `norm_norm_norm` branch back to master (ver 0.6.x) in next week or so. + * The changes are more extensive than usual and may destabilize and break some model API use (aiming for full backwards compat). So, beware `pip install git+https://github.com/rwightman/pytorch-image-models` installs! + * `0.5.x` releases and a `0.5.x` branch will remain stable with a cherry pick or two until dust clears. Recommend sticking to pypi install for a bit if you want stable. 
+ +### Jan 14, 2022 +* Version 0.5.4 w/ release to be pushed to pypi. It's been a while since last pypi update and riskier changes will be merged to main branch soon.... +* Add ConvNeXT models /w weights from official impl (https://github.com/facebookresearch/ConvNeXt), a few perf tweaks, compatible with timm features +* Tried training a few small (~1.8-3M param) / mobile optimized models, a few are good so far, more on the way... + * `mnasnet_small` - 65.6 top-1 + * `mobilenetv2_050` - 65.9 + * `lcnet_100/075/050` - 72.1 / 68.8 / 63.1 + * `semnasnet_075` - 73 + * `fbnetv3_b/d/g` - 79.1 / 79.7 / 82.0 +* TinyNet models added by [rsomani95](https://github.com/rsomani95) +* LCNet added via MobileNetV3 architecture + +### Nov 22, 2021 +* A number of updated weights anew new model defs + * `eca_halonext26ts` - 79.5 @ 256 + * `resnet50_gn` (new) - 80.1 @ 224, 81.3 @ 288 + * `resnet50` - 80.7 @ 224, 80.9 @ 288 (trained at 176, not replacing current a1 weights as default since these don't scale as well to higher res, [weights](https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1h2_176-001a1197.pth)) + * `resnext50_32x4d` - 81.1 @ 224, 82.0 @ 288 + * `sebotnet33ts_256` (new) - 81.2 @ 224 + * `lamhalobotnet50ts_256` - 81.5 @ 256 + * `halonet50ts` - 81.7 @ 256 + * `halo2botnet50ts_256` - 82.0 @ 256 + * `resnet101` - 82.0 @ 224, 82.8 @ 288 + * `resnetv2_101` (new) - 82.1 @ 224, 83.0 @ 288 + * `resnet152` - 82.8 @ 224, 83.5 @ 288 + * `regnetz_d8` (new) - 83.5 @ 256, 84.0 @ 320 + * `regnetz_e8` (new) - 84.5 @ 256, 85.0 @ 320 +* `vit_base_patch8_224` (85.8 top-1) & `in21k` variant weights added thanks [Martins Bruveris](https://github.com/martinsbruveris) +* Groundwork in for FX feature extraction thanks to [Alexander Soare](https://github.com/alexander-soare) + * models updated for tracing compatibility (almost full support with some distlled transformer exceptions) + +### Oct 19, 2021 +* ResNet strikes back (https://arxiv.org/abs/2110.00476) weights added, plus any extra training components used. Model weights and some more details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-rsb-weights) +* BCE loss and Repeated Augmentation support for RSB paper +* 4 series of ResNet based attention model experiments being added (implemented across byobnet.py/byoanet.py). These include all sorts of attention, from channel attn like SE, ECA to 2D QKV self-attention layers such as Halo, Bottlneck, Lambda. Details here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* Working implementations of the following 2D self-attention modules (likely to be differences from paper or eventual official impl): + * Halo (https://arxiv.org/abs/2103.12731) + * Bottleneck Transformer (https://arxiv.org/abs/2101.11605) + * LambdaNetworks (https://arxiv.org/abs/2102.08602) +* A RegNetZ series of models with some attention experiments (being added to). These do not follow the paper (https://arxiv.org/abs/2103.06877) in any way other than block architecture, details of official models are not available. See more here (https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-attn-weights) +* ConvMixer (https://openreview.net/forum?id=TVHS5Y4dNvM), CrossVit (https://arxiv.org/abs/2103.14899), and BeiT (https://arxiv.org/abs/2106.08254) architectures + weights added +* freeze/unfreeze helpers by [Alexander Soare](https://github.com/alexander-soare) + +### Aug 18, 2021 +* Optimizer bonanza! 
+ * Add LAMB and LARS optimizers, incl trust ratio clipping options. Tweaked to work properly in PyTorch XLA (tested on TPUs w/ `timm bits` [branch](https://github.com/rwightman/pytorch-image-models/tree/bits_and_tpu/timm/bits)) + * Add MADGRAD from FB research w/ a few tweaks (decoupled decay option, step handling that works with PyTorch XLA) + * Some cleanup on all optimizers and factory. No more `.data`, a bit more consistency, unit tests for all! + * SGDP and AdamP still won't work with PyTorch XLA but others should (have yet to test Adabelief, Adafactor, Adahessian myself). +* EfficientNet-V2 XL TF ported weights added, but they don't validate well in PyTorch (L is better). The pre-processing for the V2 TF training is a bit diff and the fine-tuned 21k -> 1k weights are very sensitive and less robust than the 1k weights. +* Added PyTorch trained EfficientNet-V2 'Tiny' w/ GlobalContext attn weights. Only .1-.2 top-1 better than the SE so more of a curiosity for those interested. + +### July 12, 2021 +* Add XCiT models from [official facebook impl](https://github.com/facebookresearch/xcit). Contributed by [Alexander Soare](https://github.com/alexander-soare) + +### July 5-9, 2021 +* Add `efficientnetv2_rw_t` weights, a custom 'tiny' 13.6M param variant that is a bit better than (non NoisyStudent) B3 models. Both faster and better accuracy (at same or lower res) + * top-1 82.34 @ 288x288 and 82.54 @ 320x320 +* Add [SAM pretrained](https://arxiv.org/abs/2106.01548) in1k weight for ViT B/16 (`vit_base_patch16_sam_224`) and B/32 (`vit_base_patch32_sam_224`) models. +* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare). + * `jx_nest_base` - 83.534, `jx_nest_small` - 83.120, `jx_nest_tiny` - 81.426 + +### June 23, 2021 +* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6) + +### June 20, 2021 +* Release Vision Transformer 'AugReg' weights from [How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers](https://arxiv.org/abs/2106.10270) + * .npz weight loading support added, can load any of the 50K+ weights from the [AugReg series](https://console.cloud.google.com/storage/browser/vit_models/augreg) + * See [example notebook](https://colab.research.google.com/github/google-research/vision_transformer/blob/master/vit_jax_augreg.ipynb) from [official impl](https://github.com/google-research/vision_transformer/) for navigating the augreg weights + * Replaced all default weights w/ best AugReg variant (if possible). All AugReg 21k classifiers work. + * Highlights: `vit_large_patch16_384` (87.1 top-1), `vit_large_r50_s32_384` (86.2 top-1), `vit_base_patch16_384` (86.0 top-1) + * `vit_deit_*` renamed to just `deit_*` + * Remove my old small model, replace with DeiT compatible small w/ AugReg weights +* Add 1st training of my `gmixer_24_224` MLP /w GLU, 78.1 top-1 w/ 25M params. +* Add weights from official ResMLP release (https://github.com/facebookresearch/deit) +* Add `eca_nfnet_l2` weights from my 'lightweight' series. 84.7 top-1 at 384x384. 
+* Add distilled BiT 50x1 student and 152x2 Teacher weights from [Knowledge distillation: A good teacher is patient and consistent](https://arxiv.org/abs/2106.05237) +* NFNets and ResNetV2-BiT models work w/ Pytorch XLA now + * weight standardization uses F.batch_norm instead of std_mean (std_mean wasn't lowered) + * eps values adjusted, will be slight differences but should be quite close +* Improve test coverage and classifier interface of non-conv (vision transformer and mlp) models +* Cleanup a few classifier / flatten details for models w/ conv classifiers or early global pool +* Please report any regressions, this PR touched quite a few models. diff --git a/docs/index.md b/docs/index.md index 95f7df64..e022f891 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ Welcome to the `timm` documentation, a lean set of docs that covers the basics of `timm`. -For a more comprehensive set of docs (currently under development), please visit [timmdocs](https://fastai.github.io/timmdocs/) by [Aman Arora](https://github.com/amaarora). +For a more comprehensive set of docs (currently under development), please visit [timmdocs](http://timm.fast.ai) by [Aman Arora](https://github.com/amaarora). ## Install @@ -20,17 +20,17 @@ pip install git+https://github.com/rwightman/pytorch-image-models.git ``` !!! info "Conda Environment" - All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x., 3.8.x., 3.9 + All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically 3.7, 3.8, 3.9, 3.10 Little to no care has been taken to be Python 2.x friendly and will not support it. If you run into any challenges running on Windows, or other OS, I'm definitely open to looking into those issues so long as it's in a reproducible (read Conda) environment. - PyTorch versions 1.4, 1.5.x, 1.6, 1.7.x, and 1.8 have been tested with this code. + PyTorch versions 1.9, 1.10, 1.11 have been tested with the latest versions of this code. I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: ``` conda create -n torch-env conda activate torch-env - conda install pytorch torchvision cudatoolkit=11.1 -c pytorch -c conda-forge + conda install pytorch torchvision cudatoolkit=11.3 -c pytorch conda install pyyaml ``` diff --git a/docs/scripts.md b/docs/scripts.md index f48eec0d..0fbf18f1 100644 --- a/docs/scripts.md +++ b/docs/scripts.md @@ -12,7 +12,7 @@ To train an SE-ResNet34 on ImageNet, locally distributed, 4 GPUs, one process pe `./distributed_train.sh 4 /data/imagenet --model seresnet34 --sched cosine --epochs 150 --warmup-epochs 5 --lr 0.4 --reprob 0.5 --remode pixel --batch-size 256 --amp -j 4` -NOTE: It is recommended to use PyTorch 1.7+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. +NOTE: It is recommended to use PyTorch 1.9+ w/ PyTorch native AMP and DDP instead of APEX AMP. `--amp` defaults to native AMP as of timm ver 0.4.3. `--apex-amp` will force use of APEX components if they are installed. ## Validation / Inference Scripts From 7629d8264d559a5bc8da751dd02f2174ff4a07bc Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Apr 2022 16:54:53 -0700 Subject: [PATCH 03/32] Add two new SE-ResNeXt101-D 32x8d weights, one anti-aliased and one not. 
Reshuffle default_cfgs vs model entrypoints for resnet.py so they are better aligned. --- timm/models/resnet.py | 432 ++++++++++++++++++++++-------------------- 1 file changed, 229 insertions(+), 203 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 7a5afb3b..d304cf42 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -148,6 +148,49 @@ default_cfgs = { 'swsl_resnext101_32x16d': _cfg( url='https://dl.fbaipublicfiles.com/semiweaksupervision/model_files/semi_weakly_supervised_resnext101_32x16-f3559a9c.pth'), + # Efficient Channel Attention ResNets + 'ecaresnet26t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnetlight': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', + interpolation='bicubic'), + 'ecaresnet50d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet50t': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), + crop_pct=0.95, test_input_size=(3, 320, 320)), + 'ecaresnet101d': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnet101d_pruned': _cfg( + url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', + interpolation='bicubic', + first_conv='conv1.0'), + 'ecaresnet200d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), + 'ecaresnet269d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', + interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), + crop_pct=1.0, test_input_size=(3, 352, 352)), + + # Efficient Channel Attention ResNeXts + 'ecaresnext26t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + 'ecaresnext50t_32x4d': _cfg( + url='', + interpolation='bicubic', first_conv='conv1.0'), + # Squeeze-Excitation ResNets, to eventually replace the models in senet.py 'seresnet18': _cfg( url='', @@ -180,7 +223,6 @@ default_cfgs = { url='', interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), - # Squeeze-Excitation ResNeXts, to eventually replace the models in senet.py 'seresnext26d_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26d_32x4d-80fa48a3.pth', @@ -199,55 +241,16 @@ default_cfgs = { 'seresnext101_32x8d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101_32x8d_ah-e6bc4c0a.pth', interpolation='bicubic', 
test_input_size=(3, 288, 288), crop_pct=1.0), + 'seresnext101d_32x8d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), + 'senet154': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), - # Efficient Channel Attention ResNets - 'ecaresnet26t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet26t_ra2-46609757.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), - crop_pct=0.95, test_input_size=(3, 320, 320)), - 'ecaresnetlight': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNetLight_4f34b35b.pth', - interpolation='bicubic'), - 'ecaresnet50d': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet50D_833caf58.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet50d_pruned': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45899/outputs/ECAResNet50D_P_9c67f710.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet50t': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet50t_ra2-f7ac63c4.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), pool_size=(8, 8), - crop_pct=0.95, test_input_size=(3, 320, 320)), - 'ecaresnet101d': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45402/outputs/ECAResNet101D_281c5844.pth', - interpolation='bicubic', first_conv='conv1.0'), - 'ecaresnet101d_pruned': _cfg( - url='https://imvl-automl-sh.oss-cn-shanghai.aliyuncs.com/darts/hyperml/hyperml/job_45610/outputs/ECAResNet101D_P_75a3370e.pth', - interpolation='bicubic', - first_conv='conv1.0'), - 'ecaresnet200d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 256, 256), crop_pct=0.94, pool_size=(8, 8)), - 'ecaresnet269d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecaresnet269d_320_ra2-7baa55cb.pth', - interpolation='bicubic', first_conv='conv1.0', input_size=(3, 320, 320), pool_size=(10, 10), - crop_pct=1.0, test_input_size=(3, 352, 352)), - - # Efficient Channel Attention ResNeXts - 'ecaresnext26t_32x4d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0'), - 'ecaresnext50t_32x4d': _cfg( - url='', - interpolation='bicubic', first_conv='conv1.0'), - - # ResNets with anti-aliasing blur pool + # ResNets with anti-aliasing / blur pool 'resnetblur18': _cfg( interpolation='bicubic'), 'resnetblur50': _cfg( @@ -268,6 +271,9 @@ default_cfgs = { 'seresnetaa50d': _cfg( url='', interpolation='bicubic', first_conv='conv1.0'), + 'seresnextaa101d_32x8d': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth', + interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), # ResNet-RS models 'resnetrs50': _cfg( @@ -1157,98 +1163,6 @@ def ecaresnet50d(pretrained=False, **kwargs): return _create_resnet('ecaresnet50d', pretrained, **model_args) -@register_model -def resnetrs50(pretrained=False, **kwargs): - """Constructs a ResNet-RS-50 model. 
- Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs50', pretrained, **model_args) - - -@register_model -def resnetrs101(pretrained=False, **kwargs): - """Constructs a ResNet-RS-101 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs101', pretrained, **model_args) - - -@register_model -def resnetrs152(pretrained=False, **kwargs): - """Constructs a ResNet-RS-152 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs152', pretrained, **model_args) - - -@register_model -def resnetrs200(pretrained=False, **kwargs): - """Constructs a ResNet-RS-200 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs200', pretrained, **model_args) - - -@register_model -def resnetrs270(pretrained=False, **kwargs): - """Constructs a ResNet-RS-270 model. - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs270', pretrained, **model_args) - - - -@register_model -def resnetrs350(pretrained=False, **kwargs): - """Constructs a ResNet-RS-350 model. 
- Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs350', pretrained, **model_args) - - -@register_model -def resnetrs420(pretrained=False, **kwargs): - """Constructs a ResNet-RS-420 model - Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 - Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs - """ - attn_layer = partial(get_attn('se'), rd_ratio=0.25) - model_args = dict( - block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, - avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) - return _create_resnet('resnetrs420', pretrained, **model_args) - - @register_model def ecaresnet50d_pruned(pretrained=False, **kwargs): """Constructs a ResNet-50-D model pruned with eca. @@ -1346,72 +1260,6 @@ def ecaresnext50t_32x4d(pretrained=False, **kwargs): return _create_resnet('ecaresnext50t_32x4d', pretrained, **model_args) -@register_model -def resnetblur18(pretrained=False, **kwargs): - """Constructs a ResNet-18 model with blur anti-aliasing - """ - model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur18', pretrained, **model_args) - - -@register_model -def resnetblur50(pretrained=False, **kwargs): - """Constructs a ResNet-50 model with blur anti-aliasing - """ - model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) - return _create_resnet('resnetblur50', pretrained, **model_args) - - -@register_model -def resnetblur50d(pretrained=False, **kwargs): - """Constructs a ResNet-50-D model with blur anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur50d', pretrained, **model_args) - - -@register_model -def resnetblur101d(pretrained=False, **kwargs): - """Constructs a ResNet-101-D model with blur anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetblur101d', pretrained, **model_args) - - -@register_model -def resnetaa50d(pretrained=False, **kwargs): - """Constructs a ResNet-50-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa50d', pretrained, **model_args) - - -@register_model -def resnetaa101d(pretrained=False, **kwargs): - """Constructs a ResNet-101-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, - stem_width=32, stem_type='deep', avg_down=True, **kwargs) - return _create_resnet('resnetaa101d', pretrained, **model_args) - - -@register_model -def seresnetaa50d(pretrained=False, **kwargs): - """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing - """ - model_args = dict( - block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, - stem_width=32, 
stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) - return _create_resnet('seresnetaa50d', pretrained, **model_args) - - @register_model def seresnet18(pretrained=False, **kwargs): model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], block_args=dict(attn_layer='se'), **kwargs) @@ -1535,9 +1383,187 @@ def seresnext101_32x8d(pretrained=False, **kwargs): return _create_resnet('seresnext101_32x8d', pretrained, **model_args) +@register_model +def seresnext101d_32x8d(pretrained=False, **kwargs): + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + stem_width=32, stem_type='deep', avg_down=True, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnext101d_32x8d', pretrained, **model_args) + + @register_model def senet154(pretrained=False, **kwargs): model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], cardinality=64, base_width=4, stem_type='deep', down_kernel_size=3, block_reduce_first=2, block_args=dict(attn_layer='se'), **kwargs) return _create_resnet('senet154', pretrained, **model_args) + + +@register_model +def resnetblur18(pretrained=False, **kwargs): + """Constructs a ResNet-18 model with blur anti-aliasing + """ + model_args = dict(block=BasicBlock, layers=[2, 2, 2, 2], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur18', pretrained, **model_args) + + +@register_model +def resnetblur50(pretrained=False, **kwargs): + """Constructs a ResNet-50 model with blur anti-aliasing + """ + model_args = dict(block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, **kwargs) + return _create_resnet('resnetblur50', pretrained, **model_args) + + +@register_model +def resnetblur50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur50d', pretrained, **model_args) + + +@register_model +def resnetblur101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with blur anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=BlurPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetblur101d', pretrained, **model_args) + + +@register_model +def resnetaa50d(pretrained=False, **kwargs): + """Constructs a ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa50d', pretrained, **model_args) + + +@register_model +def resnetaa101d(pretrained=False, **kwargs): + """Constructs a ResNet-101-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, **kwargs) + return _create_resnet('resnetaa101d', pretrained, **model_args) + + +@register_model +def seresnetaa50d(pretrained=False, **kwargs): + """Constructs a SE=ResNet-50-D model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], aa_layer=nn.AvgPool2d, + stem_width=32, stem_type='deep', avg_down=True, block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnetaa50d', pretrained, **model_args) + + +@register_model +def seresnextaa101d_32x8d(pretrained=False, **kwargs): + """Constructs a 
SE=ResNeXt-101-D 32x8d model with avgpool anti-aliasing + """ + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], cardinality=32, base_width=8, + stem_width=32, stem_type='deep', avg_down=True, aa_layer=nn.AvgPool2d, + block_args=dict(attn_layer='se'), **kwargs) + return _create_resnet('seresnextaa101d_32x8d', pretrained, **model_args) + + +@register_model +def resnetrs50(pretrained=False, **kwargs): + """Constructs a ResNet-RS-50 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs50', pretrained, **model_args) + + +@register_model +def resnetrs101(pretrained=False, **kwargs): + """Constructs a ResNet-RS-101 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs101', pretrained, **model_args) + + +@register_model +def resnetrs152(pretrained=False, **kwargs): + """Constructs a ResNet-RS-152 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs152', pretrained, **model_args) + + +@register_model +def resnetrs200(pretrained=False, **kwargs): + """Constructs a ResNet-RS-200 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs200', pretrained, **model_args) + + +@register_model +def resnetrs270(pretrained=False, **kwargs): + """Constructs a ResNet-RS-270 model. + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs270', pretrained, **model_args) + + + +@register_model +def resnetrs350(pretrained=False, **kwargs): + """Constructs a ResNet-RS-350 model. 
+ Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs350', pretrained, **model_args) + + +@register_model +def resnetrs420(pretrained=False, **kwargs): + """Constructs a ResNet-RS-420 model + Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 + Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs + """ + attn_layer = partial(get_attn('se'), rd_ratio=0.25) + model_args = dict( + block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, + avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) + return _create_resnet('resnetrs420', pretrained, **model_args) From 52ac8814024e373590e13c192dae16846cdf97ae Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 22 Apr 2022 20:55:52 -0700 Subject: [PATCH 04/32] Missed first_conv in latest seresnext 'D' default_cfgs --- timm/models/resnet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index d304cf42..a7f0c0f6 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -243,7 +243,7 @@ default_cfgs = { interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), 'seresnext101d_32x8d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnext101d_32x8d_ah-191d7b94.pth', - interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), + interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0), 'senet154': _cfg( url='', @@ -273,7 +273,7 @@ default_cfgs = { interpolation='bicubic', first_conv='conv1.0'), 'seresnextaa101d_32x8d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/seresnextaa101d_32x8d_ah-83c8ae12.pth', - interpolation='bicubic', test_input_size=(3, 288, 288), crop_pct=1.0), + interpolation='bicubic', first_conv='conv1.0', test_input_size=(3, 288, 288), crop_pct=1.0), # ResNet-RS models 'resnetrs50': _cfg( From 09e9f3defb0d4ee27cc77be4743ea5c06199255a Mon Sep 17 00:00:00 2001 From: Li Dong Date: Sat, 23 Apr 2022 13:02:29 +0800 Subject: [PATCH 05/32] migrate azure blob for beit checkpoints ## Motivation We are going to use a new blob account to store the checkpoints. ## Modification Modify the azure blob storage URLs for BEiT checkpoints. 
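For context, a minimal sketch of how these `default_cfgs` URLs are exercised (illustrative only, not part of the patch): `timm.create_model(..., pretrained=True)` downloads the checkpoint from the cfg `url`, so loading any BEiT variant is a quick way to confirm the new blob account resolves.

import timm
import torch

# pretrained=True fetches the checkpoint from the default_cfg 'url' updated in this patch
# (model name as registered in timm/models/beit.py)
model = timm.create_model('beit_base_patch16_224', pretrained=True)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000]) for the 22k->1k fine-tuned weights
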
--- timm/models/beit.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 557ac9b0..a56653dd 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -46,27 +46,27 @@ def _cfg(url='', **kwargs): default_cfgs = { 'beit_base_patch16_224': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22kto1k.pth'), 'beit_base_patch16_384': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_base_patch16_224_in22k': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_base_patch16_224_pt22k_ft22k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_base_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), 'beit_large_patch16_224': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22kto1k.pth'), 'beit_large_patch16_384': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_384_pt22k_ft22kto1k.pth', input_size=(3, 384, 384), crop_pct=1.0, ), 'beit_large_patch16_512': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_512_pt22k_ft22kto1k.pth', input_size=(3, 512, 512), crop_pct=1.0, ), 'beit_large_patch16_224_in22k': _cfg( - url='https://unilm.blob.core.windows.net/beit/beit_large_patch16_224_pt22k_ft22k.pth', + url='https://conversationhub.blob.core.windows.net/beit-share-public/beit/beit_large_patch16_224_pt22k_ft22k.pth', num_classes=21841, ), } From f88c606fcf47b6534241950e5ab9a1ce5cd7a5c5 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Mon, 25 Apr 2022 12:41:46 -0700 Subject: [PATCH 06/32] fixing channels_last on cond_conv2d; update nvfuser debug env variable --- timm/models/layers/cond_conv2d.py | 3 ++- timm/utils/jit.py | 6 +++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/cond_conv2d.py b/timm/models/layers/cond_conv2d.py index 8b4bbca8..43654c59 100644 --- a/timm/models/layers/cond_conv2d.py +++ b/timm/models/layers/cond_conv2d.py @@ -91,7 +91,8 @@ class CondConv2d(nn.Module): bias = torch.matmul(routing_weights, self.bias) bias = bias.view(B * self.out_channels) # move batch elements with channels so each batch element can be efficiently convolved with separate kernel - x = x.view(1, B * C, H, W) + # reshape instead of view to work with channels_last input + x = x.reshape(1, B * C, H, W) if self.dynamic_padding: out = conv2d_same( x, weight, bias, stride=self.stride, padding=self.padding, diff --git a/timm/utils/jit.py b/timm/utils/jit.py index 6039823f..8ebfdbff 100644 --- a/timm/utils/jit.py +++ b/timm/utils/jit.py @@ -34,9 +34,9 @@ def set_jit_fuser(fuser): torch._C._jit_override_can_fuse_on_gpu(True) torch._C._jit_set_texpr_fuser_enabled(False) elif fuser == "nvfuser" or fuser == "nvf": - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FALLBACK'] = 
'1' - os.environ['PYTORCH_CUDA_FUSER_DISABLE_FMA'] = '1' - os.environ['PYTORCH_CUDA_FUSER_JIT_OPT_LEVEL'] = '0' + os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' + os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' + os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) From b7cb8d0337b3e7b50516849805ddb9be5fc11644 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 26 Apr 2022 17:32:49 -0700 Subject: [PATCH 07/32] Add Swin-V2 Small-NS weights (83.5 @ 224). Add layer scale like 'init_values' via post-norm LN weight scaling --- timm/models/swin_transformer_v2_cr.py | 52 +++++++++++++++++++++++++-- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index bb8fe3cc..472ae205 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -76,10 +76,15 @@ default_cfgs = { 'swin_v2_cr_small_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", input_size=(3, 224, 224), crop_pct=0.9), + 'swin_v2_cr_small_ns_224': _cfg( + url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth", + input_size=(3, 224, 224), crop_pct=0.9), 'swin_v2_cr_base_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), 'swin_v2_cr_base_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), + 'swin_v2_cr_base_ns_224': _cfg( + url="", input_size=(3, 224, 224), crop_pct=0.9), 'swin_v2_cr_large_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), 'swin_v2_cr_large_224': _cfg( @@ -179,7 +184,7 @@ class WindowMultiHeadAttention(nn.Module): hidden_features=meta_hidden_dim, out_features=num_heads, act_layer=nn.ReLU, - drop=0.1 # FIXME should there be stochasticity, appears to 'overfit' without? + drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without? 
) self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) self._make_pair_wise_relative_positions() @@ -304,6 +309,7 @@ class SwinTransformerBlock(nn.Module): window_size: Tuple[int, int], shift_size: Tuple[int, int] = (0, 0), mlp_ratio: float = 4.0, + init_values: float = 0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: float = 0.0, @@ -317,6 +323,7 @@ class SwinTransformerBlock(nn.Module): self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] + self.init_values: float = init_values # attn branch self.attn = WindowMultiHeadAttention( @@ -345,6 +352,7 @@ class SwinTransformerBlock(nn.Module): self.norm3 = norm_layer(dim) if extra_norm else nn.Identity() self._make_attention_mask() + self.init_weights() def _calc_window_shift(self, target_window_size): window_size = [f if f <= w else w for f, w in zip(self.feat_size, target_window_size)] @@ -377,6 +385,12 @@ class SwinTransformerBlock(nn.Module): attn_mask = None self.register_buffer("attn_mask", attn_mask, persistent=False) + def init_weights(self): + # extra, module specific weight init + if self.init_values: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + def update_input_size(self, new_window_size: Tuple[int, int], new_feat_size: Tuple[int, int]) -> None: """Method updates the image resolution to be processed and window size and so the pair-wise relative positions. @@ -435,7 +449,7 @@ class SwinTransformerBlock(nn.Module): Returns: output (torch.Tensor): Output tensor of the shape [B, C, H, W] """ - # NOTE post-norm branches (op -> norm -> drop) + # post-norm branches (op -> norm -> drop) x = x + self.drop_path1(self.norm1(self._shifted_window_attn(x))) x = x + self.drop_path2(self.norm2(self.mlp(x))) x = self.norm3(x) # main-branch norm enabled for some blocks / stages (every 6 for Huge/Giant) @@ -522,6 +536,7 @@ class SwinTransformerStage(nn.Module): feat_size: Tuple[int, int], window_size: Tuple[int, int], mlp_ratio: float = 4.0, + init_values: float = 0.0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: Union[List[float], float] = 0.0, @@ -552,6 +567,7 @@ class SwinTransformerStage(nn.Module): window_size=window_size, shift_size=tuple([0 if ((index % 2) == 0) else w // 2 for w in window_size]), mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop, drop_attn=drop_attn, drop_path=drop_path[index] if isinstance(drop_path, list) else drop_path, @@ -634,6 +650,7 @@ class SwinTransformerV2Cr(nn.Module): depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] 
= (3, 6, 12, 24), mlp_ratio: float = 4.0, + init_values: float = 0.0, drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -674,6 +691,7 @@ class SwinTransformerV2Cr(nn.Module): num_heads=num_heads, window_size=window_size, mlp_ratio=mlp_ratio, + init_values=init_values, drop=drop_rate, drop_attn=attn_drop_rate, drop_path=drop_path_rate[sum(depths[:index]):sum(depths[:index + 1])], @@ -786,6 +804,8 @@ def init_weights(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): @@ -863,6 +883,20 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs): return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) +@register_model +def swin_v2_cr_small_ns_224(pretrained=False, **kwargs): + """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=96, + depths=(2, 2, 18, 2), + num_heads=(3, 6, 12, 24), + init_values=1e-5, + extra_norm_stage=True, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs) + + @register_model def swin_v2_cr_base_384(pretrained=False, **kwargs): """Swin-B V2 CR @ 384x384, trained ImageNet-1k""" @@ -887,6 +921,20 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs): return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) +@register_model +def swin_v2_cr_base_ns_224(pretrained=False, **kwargs): + """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" + model_kwargs = dict( + embed_dim=128, + depths=(2, 2, 18, 2), + num_heads=(4, 8, 16, 32), + init_values=1e-6, + extra_norm_stage=True, + **kwargs + ) + return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs) + + @register_model def swin_v2_cr_large_384(pretrained=False, **kwargs): """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" From 41dc49a33752b72dbb3cff5cb181b9953e07971f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 2 May 2022 15:37:39 -0700 Subject: [PATCH 08/32] Vision Transformer refactoring and Rel Pos impl --- README.md | 9 + timm/models/__init__.py | 1 + timm/models/vision_transformer.py | 190 +++++----- timm/models/vision_transformer_hybrid.py | 2 +- timm/models/vision_transformer_relpos.py | 425 +++++++++++++++++++++++ 5 files changed, 544 insertions(+), 83 deletions(-) create mode 100644 timm/models/vision_transformer_relpos.py diff --git a/README.md b/README.md index 355cedaf..df5fb968 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,15 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New + +### May 2, 2022 +* Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) + * `vit_relpos_base_patch32_plus_rpn_256` - 79.5 @ 256, 80.6 @ 320 -- rel pos + extended width + res-post-norm, no class token, avg pool + * `vit_relpos_base_patch16_224` - 82.5 @ 224, 83.6 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_base_patch16_rpn_224` - 82.3 @ 224 -- rel pos + res-post-norm, no class token, avg pool +* Vision Transformer refactor to remove representation layer that was only used in initial vit and rarely used since with newer pretrain (ie `How to Train Your ViT`) +* 
`vit_*` models support removal of class token, use of global average pool, use of fc_norm (ala beit, mae). + ### April 22, 2022 * `timm` models are now officially supported in [fast.ai](https://www.fast.ai/)! Just in time for the new Practical Deep Learning course. `timmdocs` documentation link updated to [timm.fast.ai](http://timm.fast.ai/). * Two more model weights added in the TPU trained [series](https://github.com/rwightman/pytorch-image-models/releases/tag/v0.1-tpu-weights). Some In22k pretrain still in progress. diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 45ead5dc..c1d63dcc 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -49,6 +49,7 @@ from .vgg import * from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * +from .vision_transformer_relpos import * from .volo import * from .vovnet import * from .xception import * diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 17faba53..33cc5db2 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -23,6 +23,7 @@ import math import logging from functools import partial from collections import OrderedDict +from typing import Optional import torch import torch.nn as nn @@ -107,7 +108,6 @@ default_cfgs = { 'vit_giant_patch14_224': _cfg(url=''), 'vit_gigantic_patch14_224': _cfg(url=''), - 'vit_base2_patch32_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), # patch models, imagenet21k (weights from official Google JAX impl) 'vit_tiny_patch16_224_in21k': _cfg( @@ -171,7 +171,12 @@ default_cfgs = { mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', ), - # experimental + 'vit_base_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_base_patch16_rpn_224-sw-3b07e89d.pth'), + + # experimental (may be removed) + 'vit_base_patch32_plus_256': _cfg(url='', input_size=(3, 256, 256), crop_pct=0.95), + 'vit_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240), crop_pct=0.95), 'vit_small_patch16_36x1_224': _cfg(url=''), 'vit_small_patch16_18x2_224': _cfg(url=''), 'vit_base_patch16_18x2_224': _cfg(url=''), @@ -229,8 +234,7 @@ class Block(nn.Module): self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm2 = norm_layer(dim) - mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -240,6 +244,36 @@ class Block(nn.Module): return x +class ResPostBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., init_values=None, + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self.attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + class ParallelBlock(nn.Module): def __init__( @@ -290,9 +324,9 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, representation_size=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', init_values=None, - embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True, + fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -305,33 +339,36 @@ class VisionTransformer(nn.Module): num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True - representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + init_values: (float): layer-scale init values drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate - weight_init: (str): weight init scheme - init_values: (float): layer-scale init values + weight_init (str): weight init scheme + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer """ super().__init__() assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) act_layer = act_layer or nn.GELU self.num_classes = num_classes self.global_pool = global_pool self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models - self.num_tokens = 1 + self.num_tokens = 1 if class_token else 0 self.grad_checkpointing = False self.patch_embed = embed_layer( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) num_patches = self.patch_embed.num_patches - self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) - self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None + self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule @@ -340,38 +377,21 @@ class VisionTransformer(nn.Module): dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, 
act_layer=act_layer) for i in range(depth)]) - use_fc_norm = self.global_pool == 'avg' self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity() - # Representation layer. Used for original ViT models w/ in21k pretraining. - self.representation_size = representation_size - self.pre_logits = nn.Identity() - if representation_size: - self._reset_representation(representation_size) - # Classifier Head self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity() - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() if weight_init != 'skip': self.init_weights(weight_init) - def _reset_representation(self, representation_size): - self.representation_size = representation_size - if self.representation_size: - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', nn.Linear(self.embed_dim, self.representation_size)), - ('act', nn.Tanh()) - ])) - else: - self.pre_logits = nn.Identity() - def init_weights(self, mode=''): assert mode in ('jax', 'jax_nlhb', 'moco', '') head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0. trunc_normal_(self.pos_embed, std=.02) - nn.init.normal_(self.cls_token, std=1e-6) + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) named_apply(get_init_weights_vit(mode, head_bias), self) def _init_weights(self, m): @@ -401,19 +421,17 @@ class VisionTransformer(nn.Module): def get_classifier(self): return self.head - def reset_classifier(self, num_classes: int, global_pool=None, representation_size=None): + def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: assert global_pool in ('', 'avg', 'token') self.global_pool = global_pool - if representation_size is not None: - self._reset_representation(representation_size) - final_chs = self.representation_size if self.representation_size else self.embed_dim - self.head = nn.Linear(final_chs, num_classes) if num_classes > 0 else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.patch_embed(x) - x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) x = self.pos_drop(x + self.pos_embed) if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint_seq(self.blocks, x) @@ -424,9 +442,8 @@ class VisionTransformer(nn.Module): def forward_head(self, x, pre_logits: bool = False): if self.global_pool: - x = x[:, 1:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] x = self.fc_norm(x) - x = self.pre_logits(x) return x if pre_logits else self.head(x) def forward(self, x): @@ -441,6 +458,8 @@ def init_weights_vit_timm(module: nn.Module, name: str = ''): trunc_normal_(module.weight, std=.02) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0.): @@ -449,9 +468,6 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 if name.startswith('head'): nn.init.zeros_(module.weight) nn.init.constant_(module.bias, head_bias) - elif 
name.startswith('pre_logits'): - lecun_normal_(module.weight) - nn.init.zeros_(module.bias) else: nn.init.xavier_uniform_(module.weight) if module.bias is not None: @@ -460,6 +476,8 @@ def init_weights_vit_jax(module: nn.Module, name: str = '', head_bias: float = 0 lecun_normal_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def init_weights_vit_moco(module: nn.Module, name: str = ''): @@ -473,6 +491,8 @@ def init_weights_vit_moco(module: nn.Module, name: str = ''): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) + elif hasattr(module, 'init_weights'): + module.init_weights() def get_init_weights_vit(mode='jax', head_bias: float = 0.): @@ -543,9 +563,10 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]: model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel'])) model.head.bias.copy_(_n2p(w[f'{prefix}head/bias'])) - if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: - model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) - model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) + # NOTE representation layer has been removed, not used in latest 21k/1k pretrained weights + # if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w: + # model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel'])) + # model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias'])) for i, block in enumerate(model.blocks.children()): block_prefix = f'{prefix}Transformer/encoderblock_{i}/' mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/' @@ -601,6 +622,9 @@ def checkpoint_filter_fn(state_dict, model): # To resize pos embedding when using model at different size from pretrained weights v = resize_pos_embed( v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) + elif 'pre_logits' in k: + # NOTE representation layer removed as not used in latest 21k/1k pretrained weights + continue out_dict[k] = v return out_dict @@ -609,21 +633,10 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - # NOTE this extra code to support handling of repr size for in21k pretrained models pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) - default_num_classes = pretrained_cfg['num_classes'] - num_classes = kwargs.get('num_classes', default_num_classes) - repr_size = kwargs.pop('representation_size', None) - if repr_size is not None and num_classes != default_num_classes: - # Remove representation layer if fine-tuning. This may not always be the desired action, - # but I feel better than doing nothing by default for fine-tuning. Perhaps a better interface? 
- _logger.warning("Removing representation layer for fine-tuning.") - repr_size = None - model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, - representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, pretrained_custom_load='npz' in pretrained_cfg['url'], **kwargs) @@ -696,16 +709,6 @@ def vit_base_patch32_224(pretrained=False, **kwargs): return model -@register_model -def vit_base2_patch32_256(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32) - # FIXME experiment - """ - model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, **kwargs) - model = _create_vision_transformer('vit_base2_patch32_256', pretrained=pretrained, **model_kwargs) - return model - - @register_model def vit_base_patch32_384(pretrained=False, **kwargs): """ ViT-Base model (ViT-B/32) from original paper (https://arxiv.org/abs/2010.11929). @@ -860,8 +863,7 @@ def vit_base_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -872,8 +874,7 @@ def vit_base_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -884,8 +885,7 @@ def vit_base_patch8_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) + model_kwargs = dict(patch_size=8, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch8_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -896,8 +896,7 @@ def vit_large_patch32_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=32, embed_dim=1024, depth=24, num_heads=16, representation_size=1024, **kwargs) + model_kwargs = dict(patch_size=32, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch32_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -908,8 +907,7 @@ def vit_large_patch16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. 
NOTE: this model has valid 21k classifier head and no representation (pre-logits) layer """ - model_kwargs = dict( - patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=1024, depth=24, num_heads=16, **kwargs) model = _create_vision_transformer('vit_large_patch16_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -920,8 +918,7 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. NOTE: this model has a representation layer but the 21k classifier head is zero'd out in original weights """ - model_kwargs = dict( - patch_size=14, embed_dim=1280, depth=32, num_heads=16, representation_size=1280, **kwargs) + model_kwargs = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, **kwargs) model = _create_vision_transformer('vit_huge_patch14_224_in21k', pretrained=pretrained, **model_kwargs) return model @@ -930,7 +927,6 @@ def vit_huge_patch14_224_in21k(pretrained=False, **kwargs): def vit_base_patch16_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch16_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -940,7 +936,6 @@ def vit_base_patch16_224_sam(pretrained=False, **kwargs): def vit_base_patch32_224_sam(pretrained=False, **kwargs): """ ViT-Base (ViT-B/32) w/ SAM pretrained weights. Paper: https://arxiv.org/abs/2106.01548 """ - # NOTE original SAM weights release worked with representation_size=768 model_kwargs = dict(patch_size=32, embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer('vit_base_patch32_224_sam', pretrained=pretrained, **model_kwargs) return model @@ -1002,6 +997,37 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs): return model +# Experimental models below + +@register_model +def vit_base_patch32_plus_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) + """ + model_kwargs = dict(patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch32_plus_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_plus_240(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16+) + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer('vit_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ residual post-norm + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, class_token=False, + block_fn=ResPostBlock, global_pool=kwargs.pop('global_pool', 'avg'), **kwargs) + model = _create_vision_transformer('vit_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_small_patch16_36x1_224(pretrained=False, **kwargs): """ ViT-Base w/ LayerScale + 36 x 1 (36 block serial) config. Experimental, may remove. 
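To make the head changes above concrete, here is a minimal sketch of the new pooling path when `class_token=False` and `global_pool='avg'` (illustrative shapes only, assumes just `torch`; it mirrors the `forward_head` logic added in this diff): patch tokens are averaged, normalized by `fc_norm`, and fed straight to the classifier, with no `pre_logits` layer in between.

import torch
import torch.nn as nn

embed_dim, num_classes = 768, 1000
num_tokens = 0                          # class_token=False -> no prefix tokens
fc_norm = nn.LayerNorm(embed_dim)       # use_fc_norm defaults to True when global_pool == 'avg'
head = nn.Linear(embed_dim, num_classes)

x = torch.randn(2, 196, embed_dim)      # (batch, 14x14 patch tokens, embed_dim) after the blocks
pooled = x[:, num_tokens:].mean(dim=1)  # 'avg' pools patch tokens; 'token' would take x[:, 0]
logits = head(fc_norm(pooled))
print(logits.shape)                     # torch.Size([2, 1000])
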
diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 0eee2044..24ff2096 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -295,7 +295,7 @@ def vit_base_r50_s16_224_in21k(pretrained=False, **kwargs): ImageNet-21k weights @ 224x224, source https://github.com/google-research/vision_transformer. """ backbone = _resnetv2(layers=(3, 4, 9), **kwargs) - model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, representation_size=768, **kwargs) + model_kwargs = dict(embed_dim=768, depth=12, num_heads=12, **kwargs) model = _create_vision_transformer_hybrid( 'vit_base_r50_s16_224_in21k', backbone=backbone, pretrained=pretrained, **model_kwargs) return model diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py new file mode 100644 index 00000000..056dba97 --- /dev/null +++ b/timm/models/vision_transformer_relpos.py @@ -0,0 +1,425 @@ +""" Relative Position Vision Transformer (ViT) in PyTorch + +Hacked together by / Copyright 2022, Ross Wightman +""" +import math +import logging +from functools import partial +from collections import OrderedDict +from typing import Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.utils.checkpoint import checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD +from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple +from .registry import register_model + +_logger = logging.getLogger(__name__) + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD, + 'first_conv': 'patch_embed.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = { + 'vit_relpos_base_patch32_plus_rpn_256': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', + input_size=(3, 256, 256)), + 'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)), + 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), + 'vit_relpos_base_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), +} + + +def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) -> torch.Tensor: + # cut and paste w/ modifications from swin / beit codebase + # cls to token & token 2 cls & cls to cls + # get pair-wise relative position index for each token inside the window + window_area = win_size[0] * win_size[1] + coords = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) # 2, Wh, Ww + relative_coords = coords[:, :, None] - coords[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += win_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += win_size[1] - 1 + relative_coords[:, :, 0] *= 2 * win_size[1] - 1 + if class_token: + num_relative_distance = (2 * win_size[0] - 1) * (2 * win_size[1] - 1) + 3 + relative_position_index = torch.zeros(size=(window_area + 1,) * 2, 
dtype=relative_coords.dtype) + relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + relative_position_index[0, 0:] = num_relative_distance - 3 + relative_position_index[0:, 0] = num_relative_distance - 2 + relative_position_index[0, 0] = num_relative_distance - 1 + else: + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + return relative_position_index + + +def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor: + """Method initializes the pair-wise relative positions to compute the positional biases.""" + coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) + relative_coords = coordinates[:, :, None] - coordinates[:, None, :] + relative_coords = relative_coords.permute(1, 2, 0).float() + relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs()) + return relative_coordinates_log + + +class RelPosMlp(nn.Module): + # based on timm swin-v2 impl + def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False): + super().__init__() + self.window_size = window_size + self.window_area = self.window_size[0] * self.window_size[1] + self.class_token = 1 if class_token else 0 + self.num_heads = num_heads + + self.mlp = Mlp( + 2, # x, y + hidden_features=min(128, hidden_dim * num_heads), + out_features=num_heads, + act_layer=nn.ReLU, + drop=(0.125, 0.) + ) + + self.register_buffer( + 'rel_coords_log', + gen_relative_position_log(window_size), + persistent=False + ) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0) + if self.class_token: + relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) + return relative_position_bias + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosBias(nn.Module): + + def __init__(self, window_size, num_heads, class_token=False): + super().__init__() + self.window_size = window_size + self.window_area = window_size[0] * window_size[1] + self.class_token = 1 if class_token else 0 + self.bias_shape = (self.window_area + self.class_token,) * 2 + (num_heads,) + + num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 * self.class_token + self.relative_position_bias_table = nn.Parameter(torch.zeros(num_relative_distance, num_heads)) + self.register_buffer( + "relative_position_index", + gen_relative_position_index(self.window_size, class_token=self.class_token), + persistent=False, + ) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.relative_position_bias_table, std=.02) + + def get_bias(self) -> torch.Tensor: + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.bias_shape) # win_h * win_w, win_h * win_w, num_heads + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() + return relative_position_bias + + def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): + return attn + self.get_bias() + + +class RelPosAttention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, rel_pos_cls=None, attn_drop=0., proj_drop=0.): + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.rel_pos = 
rel_pos_cls(num_heads=num_heads) if rel_pos_cls else None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + if self.rel_pos is not None: + attn = self.rel_pos(attn, shared_rel_pos=shared_rel_pos) + elif shared_rel_pos is not None: + attn = attn + shared_rel_pos + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class LayerScale(nn.Module): + def __init__(self, dim, init_values=1e-5, inplace=False): + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x): + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class RelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.ls1(self.attn(self.norm1(x), shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.ls2(self.mlp(self.norm2(x)))) + return x + + +class ResPostRelPosBlock(nn.Module): + + def __init__( + self, dim, num_heads, mlp_ratio=4., qkv_bias=False, rel_pos_cls=None, init_values=None, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.init_values = init_values + + self.attn = RelPosAttention( + dim, num_heads, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, attn_drop=attn_drop, proj_drop=drop) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.init_weights() + + def init_weights(self): + # NOTE this init overrides that base model init with specific changes for the block type + if self.init_values is not None: + nn.init.constant_(self.norm1.weight, self.init_values) + nn.init.constant_(self.norm2.weight, self.init_values) + + def forward(self, x, shared_rel_pos: Optional[torch.Tensor] = None): + x = x + self.drop_path1(self.norm1(self.attn(x, shared_rel_pos=shared_rel_pos))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class VisionTransformerRelPos(nn.Module): + """ Vision Transformer w/ Relative Position Bias + """ + + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False, + rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False, + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + global_pool (str): type of global pooling for final sequence (default: 'token') + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + init_values: (float): layer-scale init values + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + weight_init (str): weight init scheme + class_token (bool): use class token (default: False) + rel_pos_ty pe (str): type of relative position + shared_rel_pos (bool): share relative pos across all blocks + fc_norm (bool): use pre classifier norm + embed_layer (nn.Module): patch embedding layer + norm_layer: (nn.Module): normalization layer + act_layer: (nn.Module): MLP activation layer + """ + super().__init__() + assert global_pool in ('', 'avg', 'token') + assert class_token or global_pool != 'token' + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + act_layer = act_layer or nn.GELU + + self.num_classes = num_classes + self.global_pool = global_pool + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 if class_token else 0 + self.grad_checkpointing = False + + self.patch_embed = embed_layer( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + feat_size = self.patch_embed.grid_size + + rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias + rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token) + self.shared_rel_pos = None + if shared_rel_pos: + self.shared_rel_pos = rel_pos_cls(num_heads=num_heads) + # NOTE shared rel pos currently mutually exclusive w/ per-block, but could support both... 
+ rel_pos_cls = None + + self.cls_token = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim)) if self.num_tokens else None + + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + self.blocks = nn.ModuleList([ + block_fn( + dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, rel_pos_cls=rel_pos_cls, + init_values=init_values, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], + norm_layer=norm_layer, act_layer=act_layer) + for i in range(depth)]) + self.norm = norm_layer(embed_dim) if not fc_norm else nn.Identity() + + # Classifier Head + self.fc_norm = norm_layer(embed_dim) if fc_norm else nn.Identity() + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + if weight_init != 'skip': + self.init_weights(weight_init) + + def init_weights(self, mode=''): + assert mode in ('jax', 'moco', '') + if self.cls_token is not None: + nn.init.normal_(self.cls_token, std=1e-6) + # FIXME weight init scheme using PyTorch defaults curently + #named_apply(get_init_weights_vit(mode, head_bias), self) + + @torch.jit.ignore + def no_weight_decay(self): + return {'cls_token'} + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^cls_token|patch_embed', # stem and embed + blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + self.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes: int, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg', 'token') + self.global_pool = global_pool + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.cls_token is not None: + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + + shared_rel_pos = self.shared_rel_pos.get_bias() if self.shared_rel_pos is not None else None + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(blk, x, shared_rel_pos=shared_rel_pos) + else: + x = blk(x, shared_rel_pos=shared_rel_pos) + x = self.norm(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool: + x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0] + x = self.fc_norm(x) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg(VisionTransformerRelPos, variant, pretrained, **kwargs) + return model + + +@register_model +def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, + block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs): + """ 
ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token + """ + model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, + block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model From f5ca4141f710d8b0b363f849abbf0182aebc5021 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 2 May 2022 22:41:38 -0700 Subject: [PATCH 09/32] Adjust arg order for recent vit model args, add a few comments --- timm/models/vision_transformer.py | 8 +++--- timm/models/vision_transformer_relpos.py | 35 +++++++++++++----------- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 33cc5db2..59fd7849 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -325,8 +325,8 @@ class VisionTransformer(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token', embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', class_token=True, - fc_norm=None, embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): + class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='', + embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block): """ Args: img_size (int, tuple): input image size @@ -340,12 +340,12 @@ class VisionTransformer(nn.Module): mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values + class_token (bool): use class token + fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme - class_token (bool): use class token - fc_norm (Optional[bool]): pre-fc norm after pool, set if global_pool == 'avg' if None (default: None) embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 056dba97..9ecfd473 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -240,13 +240,19 @@ class ResPostRelPosBlock(nn.Module): class 
VisionTransformerRelPos(nn.Module): """ Vision Transformer w/ Relative Position Bias + + Differing from classic vit, this impl + * uses relative position index (swin v1 / beit) or relative log coord + mlp (swin v2) pos embed + * defaults to no class token (can be enabled) + * defaults to global avg pool for head (can be changed) + * layer-scale (residual branch gain) enabled """ def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None, - drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', class_token=False, - rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5, + class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): """ Args: @@ -254,21 +260,21 @@ class VisionTransformerRelPos(nn.Module): patch_size (int, tuple): patch size in_chans (int): number of input channels num_classes (int): number of classes for classification head - global_pool (str): type of global pooling for final sequence (default: 'token') + global_pool (str): type of global pooling for final sequence (default: 'avg') embed_dim (int): embedding dimension depth (int): depth of transformer num_heads (int): number of attention heads mlp_ratio (int): ratio of mlp hidden dim to embedding dim qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values + class_token (bool): use class token (default: False) + rel_pos_ty pe (str): type of relative position + shared_rel_pos (bool): share relative pos across all blocks + fc_norm (bool): use pre classifier norm instead of pre-pool drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate weight_init (str): weight init scheme - class_token (bool): use class token (default: False) - rel_pos_ty pe (str): type of relative position - shared_rel_pos (bool): share relative pos across all blocks - fc_norm (bool): use pre classifier norm embed_layer (nn.Module): patch embedding layer norm_layer: (nn.Module): normalization layer act_layer: (nn.Module): MLP activation layer @@ -384,11 +390,10 @@ def _create_vision_transformer_relpos(variant, pretrained=False, **kwargs): @register_model def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): - """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token + """ ViT-Base (ViT-B/32+) w/ relative log-coord position and residual post-norm, no class token """ model_kwargs = dict( - patch_size=32, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, - block_fn=ResPostRelPosBlock, **kwargs) + patch_size=32, embed_dim=896, depth=12, num_heads=14, block_fn=ResPostRelPosBlock, **kwargs) model = _create_vision_transformer_relpos( 'vit_relpos_base_patch32_plus_rpn_256', pretrained=pretrained, **model_kwargs) return model @@ -398,7 +403,7 @@ def vit_relpos_base_patch32_plus_rpn_256(pretrained=False, **kwargs): def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16+) w/ relative log-coord position, no class token """ - model_kwargs = dict(patch_size=16, embed_dim=896, depth=12, num_heads=14, init_values=1e-5, **kwargs) + model_kwargs = dict(patch_size=16, embed_dim=896, 
depth=12, num_heads=14, **kwargs) model = _create_vision_transformer_relpos('vit_relpos_base_patch16_plus_240', pretrained=pretrained, **model_kwargs) return model @@ -408,8 +413,7 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, - fc_norm=True, **kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, **kwargs) model = _create_vision_transformer_relpos('vit_relpos_base_patch16_224', pretrained=pretrained, **model_kwargs) return model @@ -419,7 +423,6 @@ def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token """ model_kwargs = dict( - patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, init_values=1e-5, - block_fn=ResPostRelPosBlock, **kwargs) + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) return model From 578d52e7522bb20eba36fa1ab341a37eb088dc67 Mon Sep 17 00:00:00 2001 From: okojoalg Date: Thu, 5 May 2022 23:22:40 +0900 Subject: [PATCH 10/32] Add Sequencer --- timm/models/__init__.py | 1 + timm/models/sequencer.py | 389 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 390 insertions(+) create mode 100644 timm/models/sequencer.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 45ead5dc..2b5d6031 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -39,6 +39,7 @@ from .resnetv2 import * from .rexnet import * from .selecsls import * from .senet import * +from .sequencer import * from .sknet import * from .swin_transformer import * from .swin_transformer_v2_cr import * diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py new file mode 100644 index 00000000..88003540 --- /dev/null +++ b/timm/models/sequencer.py @@ -0,0 +1,389 @@ +""" Sequencer + +Paper: `Sequencer: Deep LSTM for Image Classification` - https://arxiv.org/pdf/2205.01972.pdf + +""" +# Copyright (c) 2022. 
Yuki Tatsunami +# Licensed under the Apache License, Version 2.0 (the "License"); + + +import math +from functools import partial +from typing import Tuple + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, DEFAULT_CROP_PCT +from .helpers import build_model_with_cfg, named_apply +from .layers import lecun_normal_, DropPath, Mlp, PatchEmbed as TimmPatchEmbed +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': DEFAULT_CROP_PCT, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.proj', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + sequencer2d_s=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_s.pth"), + sequencer2d_m=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_m.pth"), + sequencer2d_l=_cfg(url="https://github.com/okojoalg/sequencer/releases/download/weights/sequencer2d_l.pth"), +) + + +def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=False): + if isinstance(module, nn.Linear): + if name.startswith('head'): + nn.init.zeros_(module.weight) + nn.init.constant_(module.bias, head_bias) + else: + if flax: + # Flax defaults + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + else: + nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + if 'mlp' in name: + nn.init.normal_(module.bias, std=1e-6) + else: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Conv2d): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.RNN, nn.GRU, nn.LSTM)): + stdv = 1.0 / math.sqrt(module.hidden_size) + for weight in module.parameters(): + nn.init.uniform_(weight, -stdv, stdv) + elif hasattr(module, 'init_weights'): + module.init_weights() + + +def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, + norm_layer, act_layer, num_layers, bidirectional, union, + with_fc, drop=0., drop_path_rate=0., **kwargs): + assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) + blocks = [] + for block_idx in range(layers[index]): + drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) + blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, + act_layer=act_layer, num_layers=num_layers, + bidirectional=bidirectional, union=union, with_fc=with_fc, + drop=drop, drop_path=drop_path)) + + if index < len(embed_dims) - 1: + blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) + + blocks = nn.Sequential(*blocks) + return blocks + + +class RNNIdentity(nn.Module): + def __init__(self, *args, **kwargs): + super(RNNIdentity, self).__init__() + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: + return x, None + + +class RNN2DBase(nn.Module): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + 
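+        # operates on NHWC feature maps: a vertical RNN scans H for every column and a horizontal
+        # RNN scans W for every row; `union` ('cat' | 'add' | 'vertical' | 'horizontal') selects how
+        # the two outputs are merged before the optional fc projection back to input_size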
super().__init__() + + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = 2 * hidden_size if bidirectional else hidden_size + self.union = union + + self.with_vertical = True + self.with_horizontal = True + self.with_fc = with_fc + + if with_fc: + if union == "cat": + self.fc = nn.Linear(2 * self.output_size, input_size) + elif union == "add": + self.fc = nn.Linear(self.output_size, input_size) + elif union == "vertical": + self.fc = nn.Linear(self.output_size, input_size) + self.with_horizontal = False + elif union == "horizontal": + self.fc = nn.Linear(self.output_size, input_size) + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + elif union == "cat": + pass + if 2 * self.output_size != input_size: + raise ValueError(f"The output channel {2 * self.output_size} is different from the input channel {input_size}.") + elif union == "add": + pass + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + elif union == "vertical": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_horizontal = False + elif union == "horizontal": + if self.output_size != input_size: + raise ValueError(f"The output channel {self.output_size} is different from the input channel {input_size}.") + self.with_vertical = False + else: + raise ValueError("Unrecognized union: " + union) + + self.rnn_v = RNNIdentity() + self.rnn_h = RNNIdentity() + + def forward(self, x): + B, H, W, C = x.shape + + if self.with_vertical: + v = x.permute(0, 2, 1, 3) + v = v.reshape(-1, H, C) + v, _ = self.rnn_v(v) + v = v.reshape(B, W, H, -1) + v = v.permute(0, 2, 1, 3) + + if self.with_horizontal: + h = x.reshape(-1, W, C) + h, _ = self.rnn_h(h) + h = h.reshape(B, H, W, -1) + + if self.with_vertical and self.with_horizontal: + if self.union == "cat": + x = torch.cat([v, h], dim=-1) + else: + x = v + h + elif self.with_vertical: + x = v + elif self.with_horizontal: + x = h + + if self.with_fc: + x = self.fc(x) + + return x + + +class LSTM2D(RNN2DBase): + + def __init__(self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): + super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) + if self.with_vertical: + self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + if self.with_horizontal: + self.rnn_h = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) + + +class Sequencer2DBlock(nn.Module): + def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, + num_layers=1, bidirectional=True, union="cat", with_fc=True, + drop=0., drop_path=0.): + super().__init__() + channels_dim = int(mlp_ratio * dim) + self.norm1 = norm_layer(dim) + self.rnn_tokens = rnn_layer(dim, hidden_size, num_layers=num_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc) + self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp_channels = mlp_layer(dim, channels_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.drop_path(self.rnn_tokens(self.norm1(x))) + x = x + self.drop_path(self.mlp_channels(self.norm2(x))) + return x + + +class PatchEmbed(TimmPatchEmbed): + def forward(self, x): + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + else: + x = x.permute(0, 2, 3, 1) # BCHW -> BHWC + x = self.norm(x) + return x + + +class Shuffle(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x): + if self.training: + B, H, W, C = x.shape + r = torch.randperm(H * W) + x = x.reshape(B, -1, C) + x = x[:, r, :].reshape(B, H, W, -1) + return x + + +class Downsample2D(nn.Module): + def __init__(self, input_dim, output_dim, patch_size): + super().__init__() + self.down = nn.Conv2d(input_dim, output_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = x.permute(0, 3, 1, 2) + x = self.down(x) + x = x.permute(0, 2, 3, 1) + return x + + +class Sequencer2D(nn.Module): + def __init__( + self, + num_classes=1000, + img_size=224, + in_chans=3, + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + block_layer=Sequencer2DBlock, + rnn_layer=LSTM2D, + mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + num_rnn_layers=1, + bidirectional=True, + union="cat", + with_fc=True, + drop_rate=0., + drop_path_rate=0., + nlhb=False, + stem_norm=False, + ): + super().__init__() + self.num_classes = num_classes + self.num_features = embed_dims[0] # num_features for consistency with other models + self.embed_dims = embed_dims + self.stem = PatchEmbed( + img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, + embed_dim=embed_dims[0], norm_layer=norm_layer if stem_norm else None, + flatten=False) + + self.blocks = nn.Sequential(*[ + get_stage( + i, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer=block_layer, + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_rnn_layers, bidirectional=bidirectional, + union=union, with_fc=with_fc, drop=drop_rate, drop_path_rate=drop_path_rate, + ) + for i, _ in enumerate(embed_dims)]) + + self.norm = norm_layer(embed_dims[-1]) + self.head = nn.Linear(embed_dims[-1], self.num_classes) if num_classes > 0 else nn.Identity() + + self.init_weights(nlhb=nlhb) + + def init_weights(self, nlhb=False): + head_bias = -math.log(self.num_classes) if nlhb else 0. 
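+        # _init_weights (applied below via named_apply) zeros the head weight and sets its bias
+        # to head_bias; nlhb = 'negative log head bias' uses -log(num_classes), the log-probability
+        # of a uniform prediction, instead of 0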
+ named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.blocks(x) + x = self.norm(x) + x = x.mean(dim=(1, 2)) + return x + + def forward(self, x): + x = self.forward_features(x) + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + return state_dict + + +def _create_sequencer2d(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Sequencer2D models.') + + model = build_model_with_cfg( + Sequencer2D, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +# main + +@register_model +def sequencer2d_s(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 8, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_s', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_m(pretrained=False, **kwargs): + model_args = dict( + layers=[4, 3, 14, 3], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_m', pretrained=pretrained, **model_args) + return model + + +@register_model +def sequencer2d_l(pretrained=False, **kwargs): + model_args = dict( + layers=[8, 8, 16, 4], + patch_sizes=[7, 2, 1, 1], + embed_dims=[192, 384, 384, 384], + hidden_sizes=[48, 96, 96, 96], + mlp_ratios=[3.0, 3.0, 3.0, 3.0], + rnn_layer=LSTM2D, + bidirectional=True, + union="cat", + with_fc=True, + **kwargs) + model = _create_sequencer2d('sequencer2d_l', pretrained=pretrained, **model_args) + return model From 57a988df3090eabfbde45a4d31996a229576bb05 Mon Sep 17 00:00:00 2001 From: han Date: Fri, 6 May 2022 13:14:43 +0900 Subject: [PATCH 11/32] fix: multistep lr decay epoch bugs - add milestones arguments - change decay_epochs to milestones variable --- timm/scheduler/scheduler_factory.py | 2 +- train.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 72a979c2..2f5a49fa 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -71,7 +71,7 @@ def create_scheduler(args, optimizer): elif args.sched == 'multistep': lr_scheduler = MultiStepLRScheduler( optimizer, - decay_t=args.decay_epochs, + decay_t=args.milestones, decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, diff --git a/train.py b/train.py index cc2b2b4d..ec8778d9 100755 --- a/train.py +++ b/train.py @@ -171,6 +171,8 @@ parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') +parser.add_argument('--milestones', default=[30, 60], type=int, nargs='+', 
metavar="MILESTONES", + help='list of epoch indices for multistep lr. must be increasing') parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', From 2fec08e9232f15cee50868123067c3c9d9014014 Mon Sep 17 00:00:00 2001 From: okojoalg Date: Fri, 6 May 2022 23:08:10 +0900 Subject: [PATCH 12/32] Add Sequencer to non std filters --- tests/test_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 4e50de6e..f06ddd95 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures From 93a79a3dd99f53e759525a5be945e8ba93678009 Mon Sep 17 00:00:00 2001 From: okojoalg Date: Fri, 6 May 2022 23:16:32 +0900 Subject: [PATCH 13/32] Fix num_features in Sequencer --- timm/models/sequencer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 88003540..3ffaf02b 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -276,7 +276,7 @@ class Sequencer2D(nn.Module): ): super().__init__() self.num_classes = num_classes - self.num_features = embed_dims[0] # num_features for consistency with other models + self.num_features = embed_dims[-1] # num_features for consistency with other models self.embed_dims = embed_dims self.stem = PatchEmbed( img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, From 37b6920df3f8a271e8286ec34c9cb277c553d16c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 9 May 2022 10:40:40 -0700 Subject: [PATCH 14/32] Fix group_matcher regex for regnet.py --- timm/models/regnet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 3e22bf56..9d1f1f64 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -458,7 +458,7 @@ class RegNet(nn.Module): def group_matcher(self, coarse=False): return dict( stem=r'^stem', - blocks=r'^stages\.(\d+)' if coarse else r'^stages\.(\d+)\.blocks\.(\d+)', + blocks=r'^s(\d+)' if coarse else r'^s(\d+)\.b(\d+)', ) @torch.jit.ignore From d79f3d9d1ed1916549240a765f8cc6f958426878 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 9 May 2022 12:09:39 -0700 Subject: [PATCH 15/32] Fix torchscript use for sequencer, add group_matcher, forward_head support, minor formatting --- timm/models/sequencer.py | 93 ++++++++++++++++++++++++++-------------- 1 file changed, 60 insertions(+), 33 deletions(-) diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 3ffaf02b..5fff04d1 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -71,18 +71,19 @@ def _init_weights(module: nn.Module, name: str, head_bias: float = 0., flax=Fals module.init_weights() -def get_stage(index, layers, patch_sizes, embed_dims, hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, - norm_layer, act_layer, num_layers, bidirectional, union, - with_fc, drop=0., drop_path_rate=0., **kwargs): +def get_stage( + index, layers, patch_sizes, embed_dims, 
hidden_sizes, mlp_ratios, block_layer, rnn_layer, mlp_layer, + norm_layer, act_layer, num_layers, bidirectional, union, + with_fc, drop=0., drop_path_rate=0., **kwargs): assert len(layers) == len(patch_sizes) == len(embed_dims) == len(hidden_sizes) == len(mlp_ratios) blocks = [] for block_idx in range(layers[index]): drop_path = drop_path_rate * (block_idx + sum(layers[:index])) / (sum(layers) - 1) - blocks.append(block_layer(embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], - rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, - act_layer=act_layer, num_layers=num_layers, - bidirectional=bidirectional, union=union, with_fc=with_fc, - drop=drop, drop_path=drop_path)) + blocks.append(block_layer( + embed_dims[index], hidden_sizes[index], mlp_ratio=mlp_ratios[index], + rnn_layer=rnn_layer, mlp_layer=mlp_layer, norm_layer=norm_layer, act_layer=act_layer, + num_layers=num_layers, bidirectional=bidirectional, union=union, with_fc=with_fc, + drop=drop, drop_path=drop_path)) if index < len(embed_dims) - 1: blocks.append(Downsample2D(embed_dims[index], embed_dims[index + 1], patch_sizes[index + 1])) @@ -101,9 +102,10 @@ class RNNIdentity(nn.Module): class RNN2DBase(nn.Module): - def __init__(self, input_size: int, hidden_size: int, - num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", with_fc=True): + def __init__( + self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): super().__init__() self.input_size = input_size @@ -115,6 +117,7 @@ class RNN2DBase(nn.Module): self.with_horizontal = True self.with_fc = with_fc + self.fc = None if with_fc: if union == "cat": self.fc = nn.Linear(2 * self.output_size, input_size) @@ -159,23 +162,27 @@ class RNN2DBase(nn.Module): v, _ = self.rnn_v(v) v = v.reshape(B, W, H, -1) v = v.permute(0, 2, 1, 3) + else: + v = None if self.with_horizontal: h = x.reshape(-1, W, C) h, _ = self.rnn_h(h) h = h.reshape(B, H, W, -1) + else: + h = None - if self.with_vertical and self.with_horizontal: + if v is not None and h is not None: if self.union == "cat": x = torch.cat([v, h], dim=-1) else: x = v + h - elif self.with_vertical: + elif v is not None: x = v - elif self.with_horizontal: + elif h is not None: x = h - if self.with_fc: + if self.fc is not None: x = self.fc(x) return x @@ -183,9 +190,10 @@ class RNN2DBase(nn.Module): class LSTM2D(RNN2DBase): - def __init__(self, input_size: int, hidden_size: int, - num_layers: int = 1, bias: bool = True, bidirectional: bool = True, - union="cat", with_fc=True): + def __init__( + self, input_size: int, hidden_size: int, + num_layers: int = 1, bias: bool = True, bidirectional: bool = True, + union="cat", with_fc=True): super().__init__(input_size, hidden_size, num_layers, bias, bidirectional, union, with_fc) if self.with_vertical: self.rnn_v = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bias=bias, bidirectional=bidirectional) @@ -194,10 +202,10 @@ class LSTM2D(RNN2DBase): class Sequencer2DBlock(nn.Module): - def __init__(self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, - norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, - num_layers=1, bidirectional=True, union="cat", with_fc=True, - drop=0., drop_path=0.): + def __init__( + self, dim, hidden_size, mlp_ratio=3.0, rnn_layer=LSTM2D, mlp_layer=Mlp, + norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=nn.GELU, + num_layers=1, bidirectional=True, union="cat", with_fc=True, drop=0., 
drop_path=0.): super().__init__() channels_dim = int(mlp_ratio * dim) self.norm1 = norm_layer(dim) @@ -255,6 +263,7 @@ class Sequencer2D(nn.Module): num_classes=1000, img_size=224, in_chans=3, + global_pool='avg', layers=[4, 3, 8, 3], patch_sizes=[7, 2, 1, 1], embed_dims=[192, 384, 384, 384], @@ -275,7 +284,9 @@ class Sequencer2D(nn.Module): stem_norm=False, ): super().__init__() + assert global_pool in ('', 'avg') self.num_classes = num_classes + self.global_pool = global_pool self.num_features = embed_dims[-1] # num_features for consistency with other models self.embed_dims = embed_dims self.stem = PatchEmbed( @@ -301,38 +312,54 @@ class Sequencer2D(nn.Module): head_bias = -math.log(self.num_classes) if nlhb else 0. named_apply(partial(_init_weights, head_bias=head_bias), module=self) # depth-first + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=[ + (r'^blocks\.(\d+)\..*\.down', (99999,)), + (r'^blocks\.(\d+)', None) if coarse else (r'^blocks\.(\d+)\.(\d+)', None), + (r'^norm', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + assert not enable, 'gradient checkpointing not supported' + + @torch.jit.ignore def get_classifier(self): return self.head - def reset_classifier(self, num_classes, global_pool=''): + def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes + if self.global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() def forward_features(self, x): x = self.stem(x) x = self.blocks(x) x = self.norm(x) - x = x.mean(dim=(1, 2)) return x + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=(1, 2)) + return x if pre_logits else self.head(x) + def forward(self, x): x = self.forward_features(x) - x = self.head(x) + x = self.forward_head(x) return x -def checkpoint_filter_fn(state_dict, model): - return state_dict - - def _create_sequencer2d(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Sequencer2D models.') - model = build_model_with_cfg( - Sequencer2D, variant, pretrained, - pretrained_filter_fn=checkpoint_filter_fn, - **kwargs) + model = build_model_with_cfg(Sequencer2D, variant, pretrained, **kwargs) return model From 78a32655fab61614b4399f08e34396539fc9026e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 9 May 2022 12:20:04 -0700 Subject: [PATCH 16/32] Fix poolformer group_matcher to merge proj downsample with previous block, support coarse --- timm/models/poolformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/poolformer.py b/timm/models/poolformer.py index 3cf6b1a3..17d657b0 100644 --- a/timm/models/poolformer.py +++ b/timm/models/poolformer.py @@ -234,8 +234,8 @@ class PoolFormer(nn.Module): return dict( stem=r'^patch_embed', # stem and embed blocks=[ - (r'^network\.(\d+)\.(\d+)', None), - (r'^network\.(\d+)', (0,)), + (r'^network\.(\d+).*\.proj', (99999,)), + (r'^network\.(\d+)', None) if coarse else (r'^network\.(\d+)\.(\d+)', None), (r'^norm', (99999,)) ], ) From 39b725e1c90dd5902f48b2eef2a800a3b221ca47 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 9 May 2022 15:20:24 -0700 Subject: [PATCH 17/32] Fix tests for rank-4 output where feature channels dim is -1 (3) and not 1 --- tests/test_models.py | 10 +++++++--- timm/models/sequencer.py 
| 3 ++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index f06ddd95..6489892c 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -202,13 +202,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size): pytest.skip("Fixed input size model > limit.") input_tensor = torch.randn((batch_size, *input_size)) + feat_dim = getattr(model, 'feature_dim', None) outputs = model.forward_features(input_tensor) if isinstance(outputs, (tuple, list)): # cannot currently verify multi-tensor output. pass else: - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features @@ -216,14 +218,16 @@ def test_model_default_cfgs_non_std(model_name, batch_size): outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config' model = create_model(model_name, pretrained=False, num_classes=0).eval() outputs = model.forward(input_tensor) if isinstance(outputs, (tuple, list)): outputs = outputs[0] - feat_dim = -1 if outputs.ndim == 3 else 1 + if feat_dim is None: + feat_dim = -1 if outputs.ndim == 3 else 1 assert outputs.shape[feat_dim] == model.num_features # check classifier name matches default_cfg diff --git a/timm/models/sequencer.py b/timm/models/sequencer.py index 5fff04d1..b1ae92a4 100644 --- a/timm/models/sequencer.py +++ b/timm/models/sequencer.py @@ -288,6 +288,7 @@ class Sequencer2D(nn.Module): self.num_classes = num_classes self.global_pool = global_pool self.num_features = embed_dims[-1] # num_features for consistency with other models + self.feature_dim = -1 # channel dim index for feature outputs (rank 4, NHWC) self.embed_dims = embed_dims self.stem = PatchEmbed( img_size=img_size, patch_size=patch_sizes[0], in_chans=in_chans, @@ -333,7 +334,7 @@ class Sequencer2D(nn.Module): def reset_classifier(self, num_classes, global_pool=None): self.num_classes = num_classes - if self.global_pool is not None: + if global_pool is not None: assert global_pool in ('', 'avg') self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() From a16171335b6d68d8578cae39ce54ab2e7d81b21d Mon Sep 17 00:00:00 2001 From: han Date: Tue, 10 May 2022 07:57:19 +0900 Subject: [PATCH 18/32] fix: change milestones to decay-milestones - change argparser option `milestone` to `decay-milestone` --- timm/scheduler/scheduler_factory.py | 2 +- train.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/scheduler/scheduler_factory.py b/timm/scheduler/scheduler_factory.py index 2f5a49fa..3e100fe0 100644 --- a/timm/scheduler/scheduler_factory.py +++ b/timm/scheduler/scheduler_factory.py @@ -71,7 +71,7 @@ def create_scheduler(args, optimizer): elif args.sched == 'multistep': lr_scheduler = MultiStepLRScheduler( optimizer, - decay_t=args.milestones, + decay_t=args.decay_milestones, decay_rate=args.decay_rate, warmup_lr_init=args.warmup_lr, warmup_t=args.warmup_epochs, diff --git a/train.py b/train.py index ec8778d9..6f31e295 100755 --- a/train.py +++ b/train.py @@ -171,8 +171,8 @@ parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', 
help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') parser.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') -parser.add_argument('--milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", - help='list of epoch indices for multistep lr. must be increasing') +parser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", + help='list of decay epoch indices for multistep lr. must be increasing') parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch interval to decay LR') parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', From 9a86b900fa6ea62d39ecd7c7a95e9523fe4cf264 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 12 May 2022 15:01:23 -0700 Subject: [PATCH 19/32] Official SwinV2 models --- timm/models/__init__.py | 1 + timm/models/helpers.py | 2 +- timm/models/swin_transformer_v2.py | 736 ++++++++++++++++++++++++++ timm/models/swin_transformer_v2_cr.py | 90 ++-- 4 files changed, 783 insertions(+), 46 deletions(-) create mode 100644 timm/models/swin_transformer_v2.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 32212fca..8cb6c70a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -42,6 +42,7 @@ from .senet import * from .sequencer import * from .sknet import * from .swin_transformer import * +from .swin_transformer_v2 import * from .swin_transformer_v2_cr import * from .tnt import * from .tresnet import * diff --git a/timm/models/helpers.py b/timm/models/helpers.py index c4f48d6a..1276b68e 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -477,7 +477,7 @@ def build_model_with_cfg( pretrained_cfg: Optional[Dict] = None, model_cfg: Optional[Any] = None, feature_cfg: Optional[Dict] = None, - pretrained_strict: bool = False, + pretrained_strict: bool = True, pretrained_filter_fn: Optional[Callable] = None, pretrained_custom_load: bool = False, kwargs_filter: Optional[Tuple[str]] = None, diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py new file mode 100644 index 00000000..29c0be9e --- /dev/null +++ b/timm/models/swin_transformer_v2.py @@ -0,0 +1,736 @@ +""" Swin Transformer V2 +A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + +Code/weights from https://github.com/microsoft/Swin-Transformer, original copyright/license info below + +Modifications and additions for timm hacked together by / Copyright 2022, Ross Wightman +""" +# -------------------------------------------------------- +# Swin Transformer V2 +# Copyright (c) 2022 Microsoft +# Licensed under The MIT License [see LICENSE for details] +# Written by Ze Liu +# -------------------------------------------------------- +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .fx_features import register_notrace_function +from .helpers import build_model_with_cfg, named_apply +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert +from .registry import register_model +from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 
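+        # shared defaults for the SwinV2 entries in default_cfgs below; individual entries override
+        # url, input_size and (for the 22k pretrained models) num_classes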
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
+        'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
+        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        **kwargs
+    }
+
+
+default_cfgs = {
+    'swinv2_tiny_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_tiny_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window16_256.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_small_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window8_256.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_small_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_small_patch4_window16_256.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_base_window8_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window8_256.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_base_window16_256': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window16_256.pth',
+        input_size=(3, 256, 256)
+    ),
+
+    'swinv2_base_window12_192_22k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192)
+    ),
+    'swinv2_base_window12to16_192to256_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to16_192to256_22kto1k_ft.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_base_window12to24_192to384_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth',
+        input_size=(3, 384, 384)
+    ),
+    'swinv2_large_window12_192_22k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth',
+        num_classes=21841, input_size=(3, 192, 192)
+    ),
+    'swinv2_large_window12to16_192to256_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to16_192to256_22kto1k_ft.pth',
+        input_size=(3, 256, 256)
+    ),
+    'swinv2_large_window12to24_192to384_22kft1k': _cfg(
+        url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth',
+        input_size=(3, 384, 384)
+    ),
+}
+
+
+def window_partition(x, window_size):
+    """
+    Args:
+        x: (B, H, W, C)
+        window_size (int): window size
+
+    Returns:
+        windows: (num_windows*B, window_size, window_size, C)
+    """
+    B, H, W, C = x.shape
+    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+    return windows
+
+
+def window_reverse(windows, window_size, H, W):
+    """
+    Args:
+        windows: (num_windows*B, window_size, window_size, C)
+        window_size (int): Window size
+        H (int): Height of image
+        W (int): Width of image
+
+    Returns:
+        x: (B, H, W, C)
+    """
+    B = int(windows.shape[0] / (H * W / window_size / window_size))
+    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+    return x
+
+
+class WindowAttention(nn.Module):
+    r""" Window based multi-head
self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__( + self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0., + pretrained_window_size=[0, 0]): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([ + relative_coords_h, + relative_coords_w])).permute(1, 2, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / math.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = 
x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pretraining. + """ + + def __init__( + self, dim, input_resolution, num_heads, window_size=7, shift_size=0, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + _assert(0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size") + + self.attn = WindowAttention( + dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, + pretrained_window_size=to_2tuple(pretrained_window_size)) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() + + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop) + self.norm2 = norm_layer(dim) + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if self.shift_size > 0: + # calculate attention mask for SW-MSA + H, W = self.input_resolution + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + cnt = 0 + for h in ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)): + for w in ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None)): + img_mask[:, h, w, :] = cnt + cnt += 1 + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + self.register_buffer("attn_mask", attn_mask) + + def _attn(self, x): + H, W = self.input_resolution + B, L, C = x.shape + _assert(L == H * W, "input feature has wrong size") + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = x + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + x = x.view(B, H * W, C) + return x + + def forward(self, x): + x = x + self.drop_path1(self.norm1(self._attn(x))) + x = x + self.drop_path2(self.norm2(self.mlp(x))) + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + _assert(L == H * W, "input feature has wrong size") + _assert(H % 2 == 0, f"x size ({H}*{W}) are not even.") + _assert(W % 2 == 0, f"x size ({H}*{W}) are not even.") + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. 
+ input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + pretrained_window_size (int): Local window size in pre-training. + """ + + def __init__( + self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0., + norm_layer=nn.LayerNorm, downsample=None, pretrained_window_size=0): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.grad_checkpointing = False + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = nn.Identity() + + def forward(self, x): + for blk in self.blocks: + if self.grad_checkpointing: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x = self.downsample(x) + return x + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class SwinTransformerV2(nn.Module): + r""" Swin Transformer V2 + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. 
Default: False + pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, global_pool='avg', + embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), + window_size=7, mlp_ratio=4., qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, + norm_layer=nn.LayerNorm, ape=False, patch_norm=True, + pretrained_window_sizes=(0, 0, 0, 0), **kwargs): + super().__init__() + + self.num_classes = num_classes + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.patch_norm = patch_norm + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + num_patches = self.patch_embed.num_patches + + # absolute position embedding + if ape: + self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + trunc_normal_(self.absolute_pos_embed, std=.02) + else: + self.absolute_pos_embed = None + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2 ** i_layer), + input_resolution=( + self.patch_embed.grid_size[0] // (2 ** i_layer), + self.patch_embed.grid_size[1] // (2 ** i_layer)), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + pretrained_window_size=pretrained_window_sizes[i_layer] + ) + self.layers.append(layer) + + self.norm = norm_layer(self.num_features) + self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + nod = {'absolute_pos_embed'} + for n, m in self.named_modules(): + if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): + nod.add(n) + return nod + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^absolute_pos_embed|patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + if global_pool is not None: + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.head = 
nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.patch_embed(x) + if self.absolute_pos_embed is not None: + x = x + self.absolute_pos_embed + x = self.pos_drop(x) + + for layer in self.layers: + x = layer(x) + + x = self.norm(x) # B L C + return x + + def forward_head(self, x, pre_logits: bool = False): + if self.global_pool == 'avg': + x = x.mean(dim=1) + return x if pre_logits else self.head(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _create_swin_transformer_v2(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + SwinTransformerV2, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + +@register_model +def swinv2_tiny_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_tiny_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_tiny_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=96, depths=(2, 2, 6, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_tiny_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_small_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_small_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_small_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), **kwargs) + return _create_swin_transformer_v2('swinv2_small_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window16_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window16_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window8_256(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=8, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window8_256', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12_192_22k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=12, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), **kwargs) + return _create_swin_transformer_v2('swinv2_base_window12_192_22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12to16_192to256_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_base_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_base_window12to24_192to384_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=24, embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 
16, 32), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_base_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12_192_22k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=12, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs) + return _create_swin_transformer_v2('swinv2_large_window12_192_22k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12to16_192to256_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=16, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_large_window12to16_192to256_22kft1k', pretrained=pretrained, **model_kwargs) + + +@register_model +def swinv2_large_window12to24_192to384_22kft1k(pretrained=False, **kwargs): + """ + """ + model_kwargs = dict( + window_size=24, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), + pretrained_window_sizes=(12, 12, 12, 6), **kwargs) + return _create_swin_transformer_v2( + 'swinv2_large_window12to24_192to384_22kft1k', pretrained=pretrained, **model_kwargs) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 472ae205..596ee204 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -64,38 +64,38 @@ def _cfg(url='', **kwargs): default_cfgs = { - 'swin_v2_cr_tiny_384': _cfg( + 'swinv2_cr_tiny_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_tiny_224': _cfg( + 'swinv2_cr_tiny_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_tiny_ns_224': _cfg( + 'swinv2_cr_tiny_ns_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_small_384': _cfg( + 'swinv2_cr_small_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_small_224': _cfg( + 'swinv2_cr_small_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_small_ns_224': _cfg( + 'swinv2_cr_small_ns_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_base_384': _cfg( + 'swinv2_cr_base_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_base_224': _cfg( + 'swinv2_cr_base_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_base_ns_224': _cfg( + 'swinv2_cr_base_ns_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_large_384': _cfg( + 'swinv2_cr_large_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_large_224': _cfg( + 'swinv2_cr_large_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_huge_384': _cfg( + 'swinv2_cr_huge_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_huge_224': _cfg( + 'swinv2_cr_huge_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), - 'swin_v2_cr_giant_384': _cfg( + 'swinv2_cr_giant_384': _cfg( url="", input_size=(3, 384, 384), crop_pct=1.0), - 'swin_v2_cr_giant_224': _cfg( + 
'swinv2_cr_giant_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), } @@ -820,7 +820,7 @@ def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): @register_model -def swin_v2_cr_tiny_384(pretrained=False, **kwargs): +def swinv2_cr_tiny_384(pretrained=False, **kwargs): """Swin-T V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -828,11 +828,11 @@ def swin_v2_cr_tiny_384(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_tiny_224(pretrained=False, **kwargs): +def swinv2_cr_tiny_224(pretrained=False, **kwargs): """Swin-T V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -840,11 +840,11 @@ def swin_v2_cr_tiny_224(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs): +def swinv2_cr_tiny_ns_224(pretrained=False, **kwargs): """Swin-T V2 CR @ 224x224, trained ImageNet-1k w/ extra stage norms. ** Experimental, may make default if results are improved. ** """ @@ -855,11 +855,11 @@ def swin_v2_cr_tiny_ns_224(pretrained=False, **kwargs): extra_norm_stage=True, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_tiny_ns_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_small_384(pretrained=False, **kwargs): +def swinv2_cr_small_384(pretrained=False, **kwargs): """Swin-S V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -867,12 +867,12 @@ def swin_v2_cr_small_384(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_small_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_small_224(pretrained=False, **kwargs): +def swinv2_cr_small_224(pretrained=False, **kwargs): """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -880,11 +880,11 @@ def swin_v2_cr_small_224(pretrained=False, **kwargs): num_heads=(3, 6, 12, 24), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_small_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_small_ns_224(pretrained=False, **kwargs): +def swinv2_cr_small_ns_224(pretrained=False, **kwargs): """Swin-S V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=96, @@ -894,11 +894,11 @@ def swin_v2_cr_small_ns_224(pretrained=False, **kwargs): extra_norm_stage=True, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_small_ns_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_small_ns_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_384(pretrained=False, **kwargs): +def swinv2_cr_base_384(pretrained=False, **kwargs): """Swin-B V2 CR @ 384x384, 
trained ImageNet-1k""" model_kwargs = dict( embed_dim=128, @@ -906,11 +906,11 @@ def swin_v2_cr_base_384(pretrained=False, **kwargs): num_heads=(4, 8, 16, 32), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_base_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_224(pretrained=False, **kwargs): +def swinv2_cr_base_224(pretrained=False, **kwargs): """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=128, @@ -918,11 +918,11 @@ def swin_v2_cr_base_224(pretrained=False, **kwargs): num_heads=(4, 8, 16, 32), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_base_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_base_ns_224(pretrained=False, **kwargs): +def swinv2_cr_base_ns_224(pretrained=False, **kwargs): """Swin-B V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=128, @@ -932,11 +932,11 @@ def swin_v2_cr_base_ns_224(pretrained=False, **kwargs): extra_norm_stage=True, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_base_ns_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_base_ns_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_large_384(pretrained=False, **kwargs): +def swinv2_cr_large_384(pretrained=False, **kwargs): """Swin-L V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=192, @@ -944,12 +944,12 @@ def swin_v2_cr_large_384(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_large_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_large_224(pretrained=False, **kwargs): +def swinv2_cr_large_224(pretrained=False, **kwargs): """Swin-L V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=192, @@ -957,11 +957,11 @@ def swin_v2_cr_large_224(pretrained=False, **kwargs): num_heads=(6, 12, 24, 48), **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_large_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_large_224', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_384(pretrained=False, **kwargs): +def swinv2_cr_huge_384(pretrained=False, **kwargs): """Swin-H V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=352, @@ -970,11 +970,11 @@ def swin_v2_cr_huge_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_384', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_huge_384', pretrained=pretrained, **model_kwargs) @register_model -def swin_v2_cr_huge_224(pretrained=False, **kwargs): +def swinv2_cr_huge_224(pretrained=False, **kwargs): """Swin-H V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=352, @@ -983,11 +983,11 @@ def swin_v2_cr_huge_224(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_huge_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_huge_224', pretrained=pretrained, **model_kwargs) @register_model -def 
swin_v2_cr_giant_384(pretrained=False, **kwargs): +def swinv2_cr_giant_384(pretrained=False, **kwargs): """Swin-G V2 CR @ 384x384, trained ImageNet-1k""" model_kwargs = dict( embed_dim=512, @@ -996,12 +996,12 @@ def swin_v2_cr_giant_384(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_384', pretrained=pretrained, **model_kwargs + return _create_swin_transformer_v2_cr('swinv2_cr_giant_384', pretrained=pretrained, **model_kwargs ) @register_model -def swin_v2_cr_giant_224(pretrained=False, **kwargs): +def swinv2_cr_giant_224(pretrained=False, **kwargs): """Swin-G V2 CR @ 224x224, trained ImageNet-1k""" model_kwargs = dict( embed_dim=512, @@ -1010,4 +1010,4 @@ def swin_v2_cr_giant_224(pretrained=False, **kwargs): extra_norm_period=6, **kwargs ) - return _create_swin_transformer_v2_cr('swin_v2_cr_giant_224', pretrained=pretrained, **model_kwargs) + return _create_swin_transformer_v2_cr('swinv2_cr_giant_224', pretrained=pretrained, **model_kwargs) From c0211b0bf79ee7e1009d04f11d27a061caa670b6 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 12 May 2022 22:31:55 -0700 Subject: [PATCH 20/32] Swin-V2 test fixes, typo --- tests/test_models.py | 2 +- timm/models/swin_transformer_v2.py | 7 +++++-- timm/models/swin_transformer_v2_cr.py | 14 +++++++------- 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 6489892c..7ea9af6e 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -25,7 +25,7 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): NON_STD_FILTERS = [ 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*', 'levit*', 'visformer*', 'deit*', 'jx_nest_*', 'nest_*', 'xcit_*', 'crossvit_*', 'beit_*', - 'poolformer_*', 'volo_*', 'sequencer2d_*'] + 'poolformer_*', 'volo_*', 'sequencer2d_*', 'swinv2_*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 29c0be9e..700012fe 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -39,7 +39,7 @@ def _cfg(url='', **kwargs): default_cfgs = { - 'swinv2_tiny_window8_256.': _cfg( + 'swinv2_tiny_window8_256': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_tiny_patch4_window8_256.pth', input_size=(3, 256, 256) ), @@ -106,6 +106,7 @@ def window_partition(x, window_size): return windows +@register_notrace_function # reason: int argument is a Proxy def window_reverse(windows, window_size, H, W): """ Args: @@ -190,9 +191,11 @@ class WindowAttention(nn.Module): self.qkv = nn.Linear(dim, dim * 3, bias=False) if qkv_bias: self.q_bias = nn.Parameter(torch.zeros(dim)) + self.register_buffer('k_bias', torch.zeros(dim), persistent=False) self.v_bias = nn.Parameter(torch.zeros(dim)) else: self.q_bias = None + self.k_bias = None self.v_bias = None self.attn_drop = nn.Dropout(attn_drop) self.proj = nn.Linear(dim, dim) @@ -208,7 +211,7 @@ class WindowAttention(nn.Module): B_, N, C = x.shape qkv_bias = None if self.q_bias is not None: - qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] # make 
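# Standalone sketch of the k_bias trick above (hypothetical module, not the timm class): a
# zero bias for k registered as a non-persistent buffer lets the qkv bias be a plain concat,
# which torch.jit.script handles cleanly and which never shows up in saved checkpoints.
import torch
import torch.nn as nn
import torch.nn.functional as F

class QKVPartialBias(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.q_bias = nn.Parameter(torch.zeros(dim))
        self.v_bias = nn.Parameter(torch.zeros(dim))
        self.register_buffer('k_bias', torch.zeros(dim), persistent=False)

    def forward(self, x):
        bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
        return F.linear(x, self.qkv.weight, bias)

scripted = torch.jit.script(QKVPartialBias(8))        # scripts without runtime tensor creation
assert 'k_bias' not in QKVPartialBias(8).state_dict()  # non-persistent buffer stays out of checkpoints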
torchscript happy (cannot use tensor as tuple) diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index 596ee204..fcfa217e 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -51,7 +51,7 @@ def _cfg(url='', **kwargs): 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), - 'pool_size': None, + 'pool_size': (7, 7), 'crop_pct': 0.9, 'interpolation': 'bicubic', 'fixed_input_size': True, @@ -65,14 +65,14 @@ def _cfg(url='', **kwargs): default_cfgs = { 'swinv2_cr_tiny_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_tiny_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_tiny_ns_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_tiny_ns_224-ba8166c6.pth", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_small_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_small_224': _cfg( url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_224-0813c165.pth", input_size=(3, 224, 224), crop_pct=0.9), @@ -80,21 +80,21 @@ default_cfgs = { url="https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights-swinv2/swin_v2_cr_small_ns_224_iv-2ce90f8e.pth", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_base_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_base_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_base_ns_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_large_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_large_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_huge_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_huge_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), 'swinv2_cr_giant_384': _cfg( - url="", input_size=(3, 384, 384), crop_pct=1.0), + url="", input_size=(3, 384, 384), crop_pct=1.0, pool_size=(12, 12)), 'swinv2_cr_giant_224': _cfg( url="", input_size=(3, 224, 224), crop_pct=0.9), } From 2f2b22d8c7889174dbf11b92c2d72d8587f9164b Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 09:26:16 -0700 Subject: [PATCH 21/32] Disable nvfuser fma / opt level overrides per #1244 --- timm/utils/jit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/utils/jit.py b/timm/utils/jit.py index 8ebfdbff..a32cbd40 100644 --- a/timm/utils/jit.py +++ b/timm/utils/jit.py @@ -35,8 +35,8 @@ def set_jit_fuser(fuser): torch._C._jit_set_texpr_fuser_enabled(False) elif fuser == "nvfuser" or fuser == "nvf": os.environ['PYTORCH_NVFUSER_DISABLE_FALLBACK'] = '1' - os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' - os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' + #os.environ['PYTORCH_NVFUSER_DISABLE_FMA'] = '1' + #os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0' torch._C._jit_set_texpr_fuser_enabled(False) torch._C._jit_set_profiling_executor(True) torch._C._jit_set_profiling_mode(True) From 27c42f0830afab4b2ff40b948cf612328ed26680 Mon Sep 17 00:00:00 2001 From: Ross 
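# Minimal sketch of exercising the helper patched above; 'nvf' is accepted as an alias in the
# branch shown, and an nvfuser-capable (CUDA) PyTorch build is assumed for the setting to have
# any effect on subsequently scripted code.
import torch
import torch.nn.functional as F
from timm.utils.jit import set_jit_fuser

set_jit_fuser('nvfuser')

@torch.jit.script
def fused_op(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return F.gelu(x) + y

print(fused_op(torch.randn(8, 8), torch.randn(8, 8)).shape)  # torch.Size([8, 8])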
Wightman Date: Fri, 13 May 2022 09:29:33 -0700 Subject: [PATCH 22/32] Fix torchscript use for offician Swin-V2, add support for non-square window/shift to WindowAttn/Block --- timm/models/swin_transformer_v2.py | 80 ++++++++++++++++-------------- 1 file changed, 43 insertions(+), 37 deletions(-) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 700012fe..8b4eff64 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -13,6 +13,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W # Written by Ze Liu # -------------------------------------------------------- import math +from typing import Tuple, Optional import torch import torch.nn as nn @@ -91,7 +92,7 @@ default_cfgs = { } -def window_partition(x, window_size): +def window_partition(x, window_size: Tuple[int, int]): """ Args: x: (B, H, W, C) @@ -101,25 +102,25 @@ def window_partition(x, window_size): windows: (num_windows*B, window_size, window_size, C) """ B, H, W, C = x.shape - x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) - windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size[0], window_size[1], C) return windows @register_notrace_function # reason: int argument is a Proxy -def window_reverse(windows, window_size, H, W): +def window_reverse(windows, window_size: Tuple[int, int], img_size: Tuple[int, int]): """ Args: - windows: (num_windows*B, window_size, window_size, C) - window_size (int): Window size - H (int): Height of image - W (int): Width of image + windows: (num_windows * B, window_size[0], window_size[1], C) + window_size (Tuple[int, int]): Window size + img_size (Tuple[int, int]): Image size Returns: x: (B, H, W, C) """ - B = int(windows.shape[0] / (H * W / window_size / window_size)) - x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + H, W = img_size + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1) x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) return x @@ -148,7 +149,7 @@ class WindowAttention(nn.Module): self.pretrained_window_size = pretrained_window_size self.num_heads = num_heads - self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) # mlp to generate continuous relative position bias self.cpb_mlp = nn.Sequential( @@ -202,7 +203,7 @@ class WindowAttention(nn.Module): self.proj_drop = nn.Dropout(proj_drop) self.softmax = nn.Softmax(dim=-1) - def forward(self, x, mask=None): + def forward(self, x, mask: Optional[torch.Tensor] = None): """ Args: x: input features with shape of (num_windows*B, N, C) @@ -218,7 +219,7 @@ class WindowAttention(nn.Module): # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) - logit_scale = torch.clamp(self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp() + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. 
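# Round-trip sketch for the generalized helpers above (tuple signatures as introduced by this
# patch, shapes as documented): partition a feature map into non-square windows and reverse
# the operation to recover the original tensor exactly.
import torch
from timm.models.swin_transformer_v2 import window_partition, window_reverse

B, H, W, C = 2, 12, 8, 16
x = torch.randn(B, H, W, C)
windows = window_partition(x, (6, 4))        # (B * H//6 * W//4, 6, 4, C) = (8, 6, 4, 16)
y = window_reverse(windows, (6, 4), (H, W))  # back to (B, H, W, C)
assert torch.equal(x, y)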
/ 0.01)).exp() attn = attn * logit_scale relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) @@ -269,16 +270,13 @@ class SwinTransformerBlock(nn.Module): act_layer=nn.GELU, norm_layer=nn.LayerNorm, pretrained_window_size=0): super().__init__() self.dim = dim - self.input_resolution = input_resolution + self.input_resolution = to_2tuple(input_resolution) self.num_heads = num_heads - self.window_size = window_size - self.shift_size = shift_size + ws, ss = self._calc_window_shift(window_size, shift_size) + self.window_size: Tuple[int, int] = ws + self.shift_size: Tuple[int, int] = ss + self.window_area = self.window_size[0] * self.window_size[1] self.mlp_ratio = mlp_ratio - if min(self.input_resolution) <= self.window_size: - # if window size is larger than input resolution, we don't partition windows - self.shift_size = 0 - self.window_size = min(self.input_resolution) - _assert(0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size") self.attn = WindowAttention( dim, window_size=to_2tuple(self.window_size), num_heads=num_heads, @@ -291,23 +289,23 @@ class SwinTransformerBlock(nn.Module): self.norm2 = norm_layer(dim) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() - if self.shift_size > 0: + if any(self.shift_size): # calculate attention mask for SW-MSA H, W = self.input_resolution img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 cnt = 0 for h in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)): for w in ( - slice(0, -self.window_size), - slice(-self.window_size, -self.shift_size), - slice(-self.shift_size, None)): + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)): img_mask[:, h, w, :] = cnt cnt += 1 mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 - mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + mask_windows = mask_windows.view(-1, self.window_area) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) else: @@ -315,6 +313,13 @@ class SwinTransformerBlock(nn.Module): self.register_buffer("attn_mask", attn_mask) + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_2tuple(target_window_size) + target_shift_size = to_2tuple(target_shift_size) + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + return tuple(window_size), tuple(shift_size) + def _attn(self, x): H, W = self.input_resolution B, L, C = x.shape @@ -322,25 +327,26 @@ class SwinTransformerBlock(nn.Module): x = x.view(B, H, W, C) # cyclic shift - if self.shift_size > 0: - shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + has_shift = any(self.shift_size) + if has_shift: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)) else: shifted_x = x # partition windows x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C - x_windows = x_windows.view(-1, 
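# Free-function restatement of _calc_window_shift above (hypothetical helper, same list
# comprehensions): per axis, a window no smaller than the feature map collapses to the
# feature-map size and its shift is disabled, so small inputs degrade gracefully.
from typing import Tuple

def calc_window_shift(
        input_resolution: Tuple[int, int],
        target_window_size: Tuple[int, int],
        target_shift_size: Tuple[int, int],
) -> Tuple[Tuple[int, int], Tuple[int, int]]:
    window_size = [r if r <= w else w for r, w in zip(input_resolution, target_window_size)]
    shift_size = [0 if r <= w else s for r, w, s in zip(input_resolution, window_size, target_shift_size)]
    return tuple(window_size), tuple(shift_size)

print(calc_window_shift((8, 32), (16, 16), (8, 8)))  # ((8, 16), (0, 8))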
self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + x_windows = x_windows.view(-1, self.window_area, C) # nW*B, window_size*window_size, C # W-MSA/SW-MSA attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size, C # merge windows - attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) - shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, self.input_resolution) # B H' W' C # reverse cyclic shift - if self.shift_size > 0: - x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + if has_shift: + x = torch.roll(shifted_x, shifts=self.shift_size, dims=(1, 2)) else: x = shifted_x x = x.view(B, H * W, C) @@ -445,7 +451,7 @@ class BasicLayer(nn.Module): def forward(self, x): for blk in self.blocks: - if self.grad_checkpointing: + if not torch.jit.is_scripting() and self.grad_checkpointing: x = checkpoint.checkpoint(blk, x) else: x = blk(x) From d4c0588012a9b5d9fddd13035a9682acd9db0ad7 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 10:49:11 -0700 Subject: [PATCH 23/32] Remove persistent buffers from Swin-V2. Change SwinV2Cr cos attn + tau/logit_scale to match official, add ckpt convert, init_value zeros resid LN weight by default --- timm/models/swin_transformer_v2.py | 26 +++++++++++------ timm/models/swin_transformer_v2_cr.py | 42 ++++++++++++++++++--------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index 8b4eff64..fe90144c 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -25,7 +25,6 @@ from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, to_ntuple, trunc_normal_, _assert from .registry import register_model -from .vision_transformer import checkpoint_filter_fn, get_init_weights_vit def _cfg(url='', **kwargs): @@ -75,7 +74,7 @@ default_cfgs = { ), 'swinv2_base_window12to24_192to384_22kft1k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_base_patch4_window12to24_192to384_22kto1k_ft.pth', - input_size=(3, 384, 384) + input_size=(3, 384, 384), crop_pct=1.0, ), 'swinv2_large_window12_192_22k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12_192_22k.pth', @@ -87,7 +86,7 @@ default_cfgs = { ), 'swinv2_large_window12to24_192to384_22kft1k': _cfg( url='https://github.com/SwinTransformer/storage/releases/download/v2.0.0/swinv2_large_patch4_window12to24_192to384_22kto1k_ft.pth', - input_size=(3, 384, 384) + input_size=(3, 384, 384), crop_pct=1.0, ), } @@ -174,7 +173,7 @@ class WindowAttention(nn.Module): relative_coords_table = torch.sign(relative_coords_table) * torch.log2( torch.abs(relative_coords_table) + 1.0) / math.log2(8) - self.register_buffer("relative_coords_table", relative_coords_table) + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) # get pair-wise relative position index for each token inside the window coords_h = torch.arange(self.window_size[0]) @@ -187,7 +186,7 @@ class WindowAttention(nn.Module): relative_coords[:, :, 1] += self.window_size[1] - 1 relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww - self.register_buffer("relative_position_index", relative_position_index) + self.register_buffer("relative_position_index", relative_position_index, persistent=False) self.qkv = nn.Linear(dim, dim * 3, bias=False) if qkv_bias: @@ -215,7 +214,7 @@ class WindowAttention(nn.Module): qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) - q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + q, k, v = qkv.unbind(0) # cosine attention attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) @@ -559,9 +558,6 @@ class SwinTransformerV2(nn.Module): trunc_normal_(m.weight, std=.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) - elif isinstance(m, nn.LayerNorm): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) @torch.jit.ignore def no_weight_decay(self): @@ -621,6 +617,18 @@ class SwinTransformerV2(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'relative_coords_table')]): + continue # skip buffers that should not be persistent + out_dict[k] = v + return out_dict + + def _create_swin_transformer_v2(variant, pretrained=False, **kwargs): model = build_model_with_cfg( SwinTransformerV2, variant, pretrained, diff --git a/timm/models/swin_transformer_v2_cr.py b/timm/models/swin_transformer_v2_cr.py index fcfa217e..d143c14c 100644 --- a/timm/models/swin_transformer_v2_cr.py +++ b/timm/models/swin_transformer_v2_cr.py @@ -34,6 +34,7 @@ from typing import Tuple, Optional, List, Union, Any, Type import torch import torch.nn as nn +import torch.nn.functional as F import torch.utils.checkpoint as checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -41,7 +42,7 @@ from .fx_features import register_notrace_function from .helpers import build_model_with_cfg, named_apply from .layers import DropPath, Mlp, to_2tuple, _assert from .registry import register_model -from .vision_transformer import checkpoint_filter_fn + _logger = logging.getLogger(__name__) @@ -186,12 +187,13 @@ class WindowMultiHeadAttention(nn.Module): act_layer=nn.ReLU, drop=(0.125, 0.) # FIXME should there be stochasticity, appears to 'overfit' without? 
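# Toy illustration of the filter above, with this patch applied (keys are made up for the
# example): the relative position buffers are dropped because they are rebuilt as
# non-persistent buffers at construction time, everything else passes through untouched.
import torch
from timm.models.swin_transformer_v2 import checkpoint_filter_fn

fake_ckpt = {'model': {
    'patch_embed.proj.weight': torch.zeros(96, 3, 4, 4),
    'layers.0.blocks.0.attn.relative_coords_table': torch.zeros(1, 15, 15, 2),
    'layers.0.blocks.0.attn.relative_position_index': torch.zeros(64, 64),
}}
kept = checkpoint_filter_fn(fake_ckpt, model=None)
print(sorted(kept))  # ['patch_embed.proj.weight']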
) - self.register_parameter("tau", torch.nn.Parameter(torch.ones(num_heads))) + # NOTE old checkpoints used inverse of logit_scale ('tau') following the paper, see conversion fn + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones(num_heads))) self._make_pair_wise_relative_positions() def _make_pair_wise_relative_positions(self) -> None: """Method initializes the pair-wise relative positions to compute the positional biases.""" - device = self.tau.device + device = self.logit_scale.device coordinates = torch.stack(torch.meshgrid([ torch.arange(self.window_size[0], device=device), torch.arange(self.window_size[1], device=device)]), dim=0).flatten(1) @@ -250,10 +252,11 @@ class WindowMultiHeadAttention(nn.Module): query, key, value = qkv.unbind(0) # compute attention map with scaled cosine attention - denom = torch.norm(query, dim=-1, keepdim=True) @ torch.norm(key, dim=-1, keepdim=True).transpose(-2, -1) - attn = query @ key.transpose(-2, -1) / denom.clamp(min=1e-6) - attn = attn / self.tau.clamp(min=0.01).reshape(1, self.num_heads, 1, 1) + attn = (F.normalize(query, dim=-1) @ F.normalize(key, dim=-1).transpose(-2, -1)) + logit_scale = torch.clamp(self.logit_scale.reshape(1, self.num_heads, 1, 1), max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale attn = attn + self._relative_positional_encodings() + if mask is not None: # Apply mask if utilized num_win: int = mask.shape[0] @@ -309,7 +312,7 @@ class SwinTransformerBlock(nn.Module): window_size: Tuple[int, int], shift_size: Tuple[int, int] = (0, 0), mlp_ratio: float = 4.0, - init_values: float = 0, + init_values: Optional[float] = 0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: float = 0.0, @@ -323,7 +326,7 @@ class SwinTransformerBlock(nn.Module): self.target_shift_size: Tuple[int, int] = to_2tuple(shift_size) self.window_size, self.shift_size = self._calc_window_shift(to_2tuple(window_size)) self.window_area = self.window_size[0] * self.window_size[1] - self.init_values: float = init_values + self.init_values: Optional[float] = init_values # attn branch self.attn = WindowMultiHeadAttention( @@ -387,7 +390,7 @@ class SwinTransformerBlock(nn.Module): def init_weights(self): # extra, module specific weight init - if self.init_values: + if self.init_values is not None: nn.init.constant_(self.norm1.weight, self.init_values) nn.init.constant_(self.norm2.weight, self.init_values) @@ -536,7 +539,7 @@ class SwinTransformerStage(nn.Module): feat_size: Tuple[int, int], window_size: Tuple[int, int], mlp_ratio: float = 4.0, - init_values: float = 0.0, + init_values: Optional[float] = 0.0, drop: float = 0.0, drop_attn: float = 0.0, drop_path: Union[List[float], float] = 0.0, @@ -650,7 +653,7 @@ class SwinTransformerV2Cr(nn.Module): depths: Tuple[int, ...] = (2, 2, 6, 2), num_heads: Tuple[int, ...] 
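# Functional sketch of the scaled cosine attention used above: similarities are cosine
# similarities of q and k, scaled by a learned per-head factor that is clamped to at most
# 1 / 0.01 = 100 (tensor names and shapes here are illustrative, not the module's own).
import math
import torch
import torch.nn.functional as F

def cosine_attention(q, k, v, logit_scale):
    # q, k, v: (B, num_heads, N, head_dim); logit_scale: (num_heads,) parameter in log space
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    scale = torch.clamp(logit_scale.reshape(1, -1, 1, 1), max=math.log(1. / 0.01)).exp()
    return (attn * scale).softmax(dim=-1) @ v

q, k, v = [torch.randn(2, 4, 16, 8) for _ in range(3)]
print(cosine_attention(q, k, v, torch.log(10 * torch.ones(4))).shape)  # torch.Size([2, 4, 16, 8])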
= (3, 6, 12, 24), mlp_ratio: float = 4.0, - init_values: float = 0.0, + init_values: Optional[float] = 0., drop_rate: float = 0.0, attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, @@ -808,6 +811,21 @@ def init_weights(module: nn.Module, name: str = ''): module.init_weights() +def checkpoint_filter_fn(state_dict, model): + """ convert patch embedding weight from manual patchify + linear proj to conv""" + out_dict = {} + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + for k, v in state_dict.items(): + if 'tau' in k: + # convert old tau based checkpoints -> logit_scale (inverse) + v = torch.log(1 / v) + k = k.replace('tau', 'logit_scale') + out_dict[k] = v + return out_dict + + def _create_swin_transformer_v2_cr(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') @@ -890,7 +908,6 @@ def swinv2_cr_small_ns_224(pretrained=False, **kwargs): embed_dim=96, depths=(2, 2, 18, 2), num_heads=(3, 6, 12, 24), - init_values=1e-5, extra_norm_stage=True, **kwargs ) @@ -928,7 +945,6 @@ def swinv2_cr_base_ns_224(pretrained=False, **kwargs): embed_dim=128, depths=(2, 2, 18, 2), num_heads=(4, 8, 16, 32), - init_values=1e-6, extra_norm_stage=True, **kwargs ) From 4b30bae67b48a8e0e4b727c952bee33a4b52aa3d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 13:53:57 -0700 Subject: [PATCH 24/32] Add updated vit_relpos weights, and impl w/ support for official swin-v2 differences for relpos. Add bias control support for MLP layers --- timm/models/layers/mlp.py | 33 ++-- timm/models/swin_transformer_v2.py | 2 +- timm/models/vision_transformer_relpos.py | 184 +++++++++++++++++++---- 3 files changed, 178 insertions(+), 41 deletions(-) diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index a85e28d0..91e80a84 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -10,16 +10,17 @@ from .helpers import to_2tuple class Mlp(nn.Module): """ MLP as used in Vision Transformer, MLP-Mixer and related networks """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -35,17 +36,18 @@ class GluMlp(nn.Module): """ MLP w/ GLU style gating See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features assert hidden_features % 2 == 0 + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) 
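# The conversion above in one line of arithmetic: the old code divided attention by tau
# (clamped to min 0.01), the new code multiplies by exp(logit_scale) clamped to max log(100),
# so storing logit_scale = log(1 / tau) keeps old checkpoints numerically equivalent.
import torch

tau = torch.tensor([0.5, 1.0, 2.0])   # per-head temperatures from an old checkpoint
logit_scale = torch.log(1 / tau)      # what checkpoint_filter_fn writes instead
assert torch.allclose(logit_scale.exp(), 1 / tau)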
self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) - self.fc2 = nn.Linear(hidden_features // 2, out_features) + self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def init_weights(self): @@ -67,14 +69,16 @@ class GluMlp(nn.Module): class GatedMlp(nn.Module): """ MLP as used in gMLP """ - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, - gate_layer=None, drop=0.): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, + gate_layer=None, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features + bias = to_2tuple(bias) drop_probs = to_2tuple(drop) - self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) if gate_layer is not None: @@ -83,7 +87,7 @@ class GatedMlp(nn.Module): hidden_features = hidden_features // 2 # FIXME base reduction on gate property? else: self.gate = nn.Identity() - self.fc2 = nn.Linear(hidden_features, out_features) + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): @@ -100,15 +104,18 @@ class ConvMlp(nn.Module): """ MLP using 1x1 convs that keeps spatial dims """ def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, + norm_layer=None, bias=True, drop=0.): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features - self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() self.act = act_layer() - self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) def forward(self, x): x = self.fc1(x) diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py index fe90144c..0c9db3dd 100644 --- a/timm/models/swin_transformer_v2.py +++ b/timm/models/swin_transformer_v2.py @@ -450,7 +450,7 @@ class BasicLayer(nn.Module): def forward(self, x): for blk in self.blocks: - if not torch.jit.is_scripting() and self.grad_checkpointing: + if self.grad_checkpointing and not torch.jit.is_scripting(): x = checkpoint.checkpoint(blk, x) else: x = blk(x) diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 9ecfd473..0c2ac376 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -1,5 +1,7 @@ """ Relative Position Vision Transformer (ViT) in PyTorch +NOTE: these models are experimental / WIP, expect changes + Hacked together by / Copyright 2022, Ross Wightman """ import math @@ -37,9 +39,23 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_replos_base_patch32_plus_rpn_256-sw-dd486f51.pth', input_size=(3, 256, 256)), 'vit_relpos_base_patch16_plus_240': _cfg(url='', input_size=(3, 240, 240)), - 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), + + 
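# Quick check of the per-layer bias control added above: a (True, False) tuple keeps the fc1
# bias and drops the fc2 bias, which is the combination the swin-style relative position MLP
# relies on later in this patch series.
import torch
from timm.models.layers import Mlp

mlp = Mlp(2, hidden_features=128, out_features=4, bias=(True, False))
assert mlp.fc1.bias is not None and mlp.fc2.bias is None
print(mlp(torch.randn(10, 2)).shape)  # torch.Size([10, 4])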
'vit_relpos_small_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_small_patch16_224-sw-ec2778b4.pth'), + 'vit_relpos_medium_patch16_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_224-sw-11c174af.pth'), 'vit_relpos_base_patch16_224': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_224-sw-49049aed.pth'), + + 'vit_relpos_base_patch16_cls_224': _cfg( + url=''), + 'vit_relpos_base_patch16_gapcls_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_base_patch16_gapcls_224-sw-1a341d6c.pth'), + + 'vit_relpos_small_patch16_rpn_224': _cfg(url=''), + 'vit_relpos_medium_patch16_rpn_224': _cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tpu-weights/vit_relpos_medium_patch16_rpn_224-sw-5d2befd8.pth'), + 'vit_relpos_base_patch16_rpn_224': _cfg(url=''), } @@ -66,43 +82,84 @@ def gen_relative_position_index(win_size: Tuple[int, int], class_token: int = 0) return relative_position_index -def gen_relative_position_log(win_size: Tuple[int, int]) -> torch.Tensor: - """Method initializes the pair-wise relative positions to compute the positional biases.""" - coordinates = torch.stack(torch.meshgrid([torch.arange(win_size[0]), torch.arange(win_size[1])])).flatten(1) - relative_coords = coordinates[:, :, None] - coordinates[:, None, :] - relative_coords = relative_coords.permute(1, 2, 0).float() - relative_coordinates_log = torch.sign(relative_coords) * torch.log(1.0 + relative_coords.abs()) - return relative_coordinates_log +def gen_relative_log_coords( + win_size: Tuple[int, int], + pretrained_win_size: Tuple[int, int] = (0, 0), + mode='swin' +): + # as per official swin-v2 impl, supporting timm swin-v2-cr coords as well + relative_coords_h = torch.arange(-(win_size[0] - 1), win_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(win_size[1] - 1), win_size[1], dtype=torch.float32) + relative_coords_table = torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + relative_coords_table = relative_coords_table.permute(1, 2, 0).contiguous() # 2*Wh-1, 2*Ww-1, 2 + if mode == 'swin': + if pretrained_win_size[0] > 0: + relative_coords_table[:, :, 0] /= (pretrained_win_size[0] - 1) + relative_coords_table[:, :, 1] /= (pretrained_win_size[1] - 1) + else: + relative_coords_table[:, :, 0] /= (win_size[0] - 1) + relative_coords_table[:, :, 1] /= (win_size[1] - 1) + relative_coords_table *= 8 # normalize to -8, 8 + scale = math.log2(8) + else: + # FIXME we should support a form of normalization (to -1/1) for this mode? 
+ scale = math.log2(math.e) + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + 1.0 + relative_coords_table.abs()) / scale + return relative_coords_table class RelPosMlp(nn.Module): - # based on timm swin-v2 impl - def __init__(self, window_size, num_heads=8, hidden_dim=32, class_token=False): + def __init__( + self, + window_size, + num_heads=8, + hidden_dim=128, + class_token=False, + mode='cr', + pretrained_window_size=(0, 0) + ): super().__init__() self.window_size = window_size self.window_area = self.window_size[0] * self.window_size[1] self.class_token = 1 if class_token else 0 self.num_heads = num_heads + self.bias_shape = (self.window_area,) * 2 + (num_heads,) + self.apply_sigmoid = mode == 'swin' + mlp_bias = (True, False) if mode == 'swin' else True self.mlp = Mlp( 2, # x, y - hidden_features=min(128, hidden_dim * num_heads), + hidden_features=hidden_dim, out_features=num_heads, act_layer=nn.ReLU, + bias=mlp_bias, drop=(0.125, 0.) ) self.register_buffer( - 'rel_coords_log', - gen_relative_position_log(window_size), - persistent=False - ) + "relative_position_index", + gen_relative_position_index(window_size), + persistent=False) + + # get relative_coords_table + self.register_buffer( + "rel_coords_log", + gen_relative_log_coords(window_size, pretrained_window_size, mode=mode), + persistent=False) def get_bias(self) -> torch.Tensor: - relative_position_bias = self.mlp(self.rel_coords_log).permute(2, 0, 1).unsqueeze(0) + relative_position_bias = self.mlp(self.rel_coords_log) + if self.relative_position_index is not None: + relative_position_bias = relative_position_bias.view(-1, self.num_heads)[ + self.relative_position_index.view(-1)] # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.view(self.bias_shape) + relative_position_bias = relative_position_bias.permute(2, 0, 1) + if self.apply_sigmoid: + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) if self.class_token: relative_position_bias = F.pad(relative_position_bias, [self.class_token, 0, self.class_token, 0]) - return relative_position_bias + return relative_position_bias.unsqueeze(0).contiguous() def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): return attn + self.get_bias() @@ -131,10 +188,10 @@ class RelPosBias(nn.Module): trunc_normal_(self.relative_position_bias_table, std=.02) def get_bias(self) -> torch.Tensor: - relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( - self.bias_shape) # win_h * win_w, win_h * win_w, num_heads - relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() - return relative_position_bias + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)] + # win_h * win_w, win_h * win_w, num_heads + relative_position_bias = relative_position_bias.view(self.bias_shape).permute(2, 0, 1) + return relative_position_bias.unsqueeze(0).contiguous() def forward(self, attn, shared_rel_pos: Optional[torch.Tensor] = None): return attn + self.get_bias() @@ -250,8 +307,8 @@ class VisionTransformerRelPos(nn.Module): def __init__( self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='avg', - embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-5, - class_token=False, rel_pos_type='mlp', shared_rel_pos=False, fc_norm=False, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=1e-6, + class_token=False, fc_norm=False, rel_pos_type='mlp', 
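# Shape sketch for the MLP relative-position bias above (with this patch applied): in both
# modes get_bias returns (1, num_heads, N, N), ready to add to the attention logits; 'swin'
# mode additionally routes the values through 16 * sigmoid and the (True, False) MLP bias.
import torch
from timm.models.vision_transformer_relpos import RelPosMlp

for mode in ('cr', 'swin'):
    rp = RelPosMlp(window_size=(14, 14), num_heads=12, mode=mode)
    print(mode, tuple(rp.get_bias().shape))  # (1, 12, 196, 196) for both modes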
shared_rel_pos=False, rel_pos_dim=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='skip', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=RelPosBlock): """ @@ -268,9 +325,9 @@ class VisionTransformerRelPos(nn.Module): qkv_bias (bool): enable bias for qkv if True init_values: (float): layer-scale init values class_token (bool): use class token (default: False) + fc_norm (bool): use pre classifier norm instead of pre-pool rel_pos_ty pe (str): type of relative position shared_rel_pos (bool): share relative pos across all blocks - fc_norm (bool): use pre classifier norm instead of pre-pool drop_rate (float): dropout rate attn_drop_rate (float): attention dropout rate drop_path_rate (float): stochastic depth rate @@ -295,8 +352,15 @@ class VisionTransformerRelPos(nn.Module): img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) feat_size = self.patch_embed.grid_size - rel_pos_cls = RelPosMlp if rel_pos_type == 'mlp' else RelPosBias - rel_pos_cls = partial(rel_pos_cls, window_size=feat_size, class_token=class_token) + rel_pos_args = dict(window_size=feat_size, class_token=class_token) + if rel_pos_type.startswith('mlp'): + if rel_pos_dim: + rel_pos_args['hidden_dim'] = rel_pos_dim + if 'swin' in rel_pos_type: + rel_pos_args['mode'] = 'swin' + rel_pos_cls = partial(RelPosMlp, **rel_pos_args) + else: + rel_pos_cls = partial(RelPosBias, **rel_pos_args) self.shared_rel_pos = None if shared_rel_pos: self.shared_rel_pos = rel_pos_cls(num_heads=num_heads) @@ -408,6 +472,26 @@ def vit_relpos_base_patch16_plus_240(pretrained=False, **kwargs): return model +@register_model +def vit_relpos_small_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_small_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, fc_norm=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_medium_patch16_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_relpos_base_patch16_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position, no class token @@ -418,11 +502,57 @@ def vit_relpos_base_patch16_224(pretrained=False, **kwargs): return model +@register_model +def vit_relpos_base_patch16_cls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present + """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, + class_token=True, global_pool='token', **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_cls_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_base_patch16_gapcls_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position, class token present + NOTE this config is a bit of a mistake, class token was enabled but global avg-pool w/ fc-norm was not disabled + Leaving here for comparisons w/ a future re-train as it performs quite well. 
+ """ + model_kwargs = dict( + patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, fc_norm=True, class_token=True, **kwargs) + model = _create_vision_transformer_relpos('vit_relpos_base_patch16_gapcls_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_small_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=384, depth=12, num_heads=6, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_small_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + +@register_model +def vit_relpos_medium_patch16_rpn_224(pretrained=False, **kwargs): + """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token + """ + model_kwargs = dict( + patch_size=16, embed_dim=512, depth=12, num_heads=8, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_medium_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + return model + + @register_model def vit_relpos_base_patch16_rpn_224(pretrained=False, **kwargs): """ ViT-Base (ViT-B/16) w/ relative log-coord position and residual post-norm, no class token """ model_kwargs = dict( patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, block_fn=ResPostRelPosBlock, **kwargs) - model = _create_vision_transformer_relpos('vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) + model = _create_vision_transformer_relpos( + 'vit_relpos_base_patch16_rpn_224', pretrained=pretrained, **model_kwargs) return model From 347308faadbe98a7e736371fffa013c28543530c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 13:54:41 -0700 Subject: [PATCH 25/32] Update README.md, version to 0.6.2 --- README.md | 12 ++++++++++++ timm/version.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index df5fb968..f90c6abd 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,17 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### May 13, 2022 +* Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. +* Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. 
+* More Vision Transformer relative position / residual post-norm experiments w/ 512 dim + * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool + * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool + * `vit_relpos_base_patch16_gapcls_224` - 82.8 @ 224, 83.9 @ 320 -- rel pos, layer scale, class token, avg pool (by mistake) +* Bring 512 dim, 8-head 'medium' ViT model variant back to life (after using in a pre DeiT 'small' model for first ViT impl back in 2020) +* Add ViT relative position support for switching btw existing impl and some additions in official Swin-V2 impl for future trials +* Sequencer2D impl (https://arxiv.org/abs/2205.01972), added via PR from author (https://github.com/okojoalg) ### May 2, 2022 * Vision Transformer experiments adding Relative Position (Swin-V2 log-coord) (`vision_transformer_relpos.py`) and Residual Post-Norm branches (from Swin-V2) (`vision_transformer*.py`) @@ -390,6 +401,7 @@ A full version of the list below with source links can be found in the [document * ReXNet - https://arxiv.org/abs/2007.00992 * SelecSLS - https://arxiv.org/abs/1907.00837 * Selective Kernel Networks - https://arxiv.org/abs/1903.06586 +* Sequencer2D - https://arxiv.org/abs/2205.01972 * Swin S3 (AutoFormerV2) - https://arxiv.org/abs/2111.14725 * Swin Transformer - https://arxiv.org/abs/2103.14030 * Swin Transformer V2 - https://arxiv.org/abs/2111.09883 diff --git a/timm/version.py b/timm/version.py index 8411e551..aece342d 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.1' +__version__ = '0.6.2' From 137292644b8a959add48d4d63663d2f9fefbab21 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 15:28:04 -0700 Subject: [PATCH 26/32] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index f90c6abd..cd9b6071 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ### May 13, 2022 * Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. * Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. 
-* More Vision Transformer relative position / residual post-norm experiments w/ 512 dim +* More Vision Transformer relative position / residual post-norm experiments w/ 512 dim (all trained on TPU thanks to TRC program) * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool From 3409bb91568b107a44d82c3552b02ac7edefa032 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 13 May 2022 15:28:46 -0700 Subject: [PATCH 27/32] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cd9b6071..4c39c692 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ### May 13, 2022 * Official Swin-V2 models and weights added from (https://github.com/microsoft/Swin-Transformer). Cleaned up to support torchscript. * Some refactoring for existing `timm` Swin-V2-CR impl, will likely do a bit more to bring parts closer to official and decide whether to merge some aspects. -* More Vision Transformer relative position / residual post-norm experiments w/ 512 dim (all trained on TPU thanks to TRC program) +* More Vision Transformer relative position / residual post-norm experiments (all trained on TPU thanks to TRC program) * `vit_relpos_small_patch16_224` - 81.5 @ 224, 82.5 @ 320 -- rel pos, layer scale, no class token, avg pool * `vit_relpos_medium_patch16_rpn_224` - 82.3 @ 224, 83.1 @ 320 -- rel pos + res-post-norm, no class token, avg pool * `vit_relpos_medium_patch16_224` - 82.5 @ 224, 83.3 @ 320 -- rel pos, layer scale, no class token, avg pool From 20a1fa63f8ea999dab29d927d5e1866ed3b67348 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 15 May 2022 14:29:57 -0700 Subject: [PATCH 28/32] Make dev version 0.6.2.dev0 for pypi pre --- timm/version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timm/version.py b/timm/version.py index aece342d..3e8e43bd 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.6.2' +__version__ = '0.6.2.dev0' From e1e4c9bbae292ee983f0e606283f66ee0598b1d4 Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Wed, 18 May 2022 10:17:02 -0400 Subject: [PATCH 29/32] rm whitespace --- train.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/train.py b/train.py index 6f31e295..b6aade12 100755 --- a/train.py +++ b/train.py @@ -55,7 +55,7 @@ except AttributeError: try: import wandb has_wandb = True -except ImportError: +except ImportError: has_wandb = False torch.backends.cudnn.benchmark = True @@ -326,14 +326,14 @@ def _parse_args(): def main(): setup_default_logging() args, args_text = _parse_args() - + if args.log_wandb: if has_wandb: wandb.init(project=args.experiment, config=args) - else: + else: _logger.warning("You've requested to log metrics to wandb but package not found. 
" "Metrics not being logged to wandb, try `pip install wandb`") - + args.prefetcher = not args.no_prefetcher args.distributed = False if 'WORLD_SIZE' in os.environ: From dcad288fd6ff5109042c6fe61994db5cc5e55f3a Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Wed, 18 May 2022 10:27:33 -0400 Subject: [PATCH 30/32] use argparse groups to group arguments --- train.py | 227 +++++++++++++++++++++++++++++-------------------------- 1 file changed, 118 insertions(+), 109 deletions(-) diff --git a/train.py b/train.py index b6aade12..c953eb02 100755 --- a/train.py +++ b/train.py @@ -71,238 +71,247 @@ parser.add_argument('-c', '--config', default='', type=str, metavar='FILE', parser = argparse.ArgumentParser(description='PyTorch ImageNet Training') # Dataset parameters +group = parser.add_argument_group('Dataset parameters') +# Keep this argument outside of the dataset group because it is positional. parser.add_argument('data_dir', metavar='DIR', help='path to dataset') -parser.add_argument('--dataset', '-d', metavar='NAME', default='', +group.add_argument('--dataset', '-d', metavar='NAME', default='', help='dataset type (default: ImageFolder/ImageTar if empty)') -parser.add_argument('--train-split', metavar='NAME', default='train', +group.add_argument('--train-split', metavar='NAME', default='train', help='dataset train split (default: train)') -parser.add_argument('--val-split', metavar='NAME', default='validation', +group.add_argument('--val-split', metavar='NAME', default='validation', help='dataset validation split (default: validation)') -parser.add_argument('--dataset-download', action='store_true', default=False, +group.add_argument('--dataset-download', action='store_true', default=False, help='Allow download of dataset for torch/ and tfds/ datasets that support it.') -parser.add_argument('--class-map', default='', type=str, metavar='FILENAME', +group.add_argument('--class-map', default='', type=str, metavar='FILENAME', help='path to class to idx mapping file (default: "")') # Model parameters -parser.add_argument('--model', default='resnet50', type=str, metavar='MODEL', +group = parser.add_argument_group('Model parameters') +group.add_argument('--model', default='resnet50', type=str, metavar='MODEL', help='Name of model to train (default: "resnet50"') -parser.add_argument('--pretrained', action='store_true', default=False, +group.add_argument('--pretrained', action='store_true', default=False, help='Start with pretrained version of specified network (if avail)') -parser.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', +group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH', help='Initialize model from this checkpoint (default: none)') -parser.add_argument('--resume', default='', type=str, metavar='PATH', +group.add_argument('--resume', default='', type=str, metavar='PATH', help='Resume full model and optimizer state from checkpoint (default: none)') -parser.add_argument('--no-resume-opt', action='store_true', default=False, +group.add_argument('--no-resume-opt', action='store_true', default=False, help='prevent resume of optimizer state when resuming model') -parser.add_argument('--num-classes', type=int, default=None, metavar='N', +group.add_argument('--num-classes', type=int, default=None, metavar='N', help='number of label classes (Model default if None)') -parser.add_argument('--gp', default=None, type=str, metavar='POOL', +group.add_argument('--gp', default=None, type=str, metavar='POOL', help='Global pool type, one of (fast, avg, 
max, avgmax, avgmaxc). Model default if None.') -parser.add_argument('--img-size', type=int, default=None, metavar='N', +group.add_argument('--img-size', type=int, default=None, metavar='N', help='Image patch size (default: None => model default)') -parser.add_argument('--input-size', default=None, nargs=3, type=int, +group.add_argument('--input-size', default=None, nargs=3, type=int, metavar='N N N', help='Input all image dimensions (d h w, e.g. --input-size 3 224 224), uses model default if empty') -parser.add_argument('--crop-pct', default=None, type=float, +group.add_argument('--crop-pct', default=None, type=float, metavar='N', help='Input image center crop percent (for validation only)') -parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', +group.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override mean pixel value of dataset') -parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', +group.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', help='Override std deviation of dataset') -parser.add_argument('--interpolation', default='', type=str, metavar='NAME', +group.add_argument('--interpolation', default='', type=str, metavar='NAME', help='Image resize interpolation type (overrides model)') -parser.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', +group.add_argument('-b', '--batch-size', type=int, default=128, metavar='N', help='Input batch size for training (default: 128)') -parser.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', +group.add_argument('-vb', '--validation-batch-size', type=int, default=None, metavar='N', help='Validation batch size override (default: None)') -parser.add_argument('--channels-last', action='store_true', default=False, +group.add_argument('--channels-last', action='store_true', default=False, help='Use channels_last memory layout') -parser.add_argument('--torchscript', dest='torchscript', action='store_true', +group.add_argument('--torchscript', dest='torchscript', action='store_true', help='torch.jit.script the full model') -parser.add_argument('--fuser', default='', type=str, +group.add_argument('--fuser', default='', type=str, help="Select jit fuser. 
One of ('', 'te', 'old', 'nvfuser')") -parser.add_argument('--grad-checkpointing', action='store_true', default=False, +group.add_argument('--grad-checkpointing', action='store_true', default=False, help='Enable gradient checkpointing through model blocks/stages') # Optimizer parameters -parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', +group = parser.add_argument_group('Optimizer parameters') +group.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER', help='Optimizer (default: "sgd"') -parser.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', +group.add_argument('--opt-eps', default=None, type=float, metavar='EPSILON', help='Optimizer Epsilon (default: None, use opt default)') -parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', +group.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA', help='Optimizer Betas (default: None, use opt default)') -parser.add_argument('--momentum', type=float, default=0.9, metavar='M', +group.add_argument('--momentum', type=float, default=0.9, metavar='M', help='Optimizer momentum (default: 0.9)') -parser.add_argument('--weight-decay', type=float, default=2e-5, +group.add_argument('--weight-decay', type=float, default=2e-5, help='weight decay (default: 2e-5)') -parser.add_argument('--clip-grad', type=float, default=None, metavar='NORM', +group.add_argument('--clip-grad', type=float, default=None, metavar='NORM', help='Clip gradient norm (default: None, no clipping)') -parser.add_argument('--clip-mode', type=str, default='norm', +group.add_argument('--clip-mode', type=str, default='norm', help='Gradient clipping mode. One of ("norm", "value", "agc")') -parser.add_argument('--layer-decay', type=float, default=None, +group.add_argument('--layer-decay', type=float, default=None, help='layer-wise learning rate decay (default: None)') # Learning rate schedule parameters -parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', +group = parser.add_argument_group('Learning rate schedule parameters') +group.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER', help='LR scheduler (default: "step"') -parser.add_argument('--lr', type=float, default=0.05, metavar='LR', +group.add_argument('--lr', type=float, default=0.05, metavar='LR', help='learning rate (default: 0.05)') -parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', +group.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct', help='learning rate noise on/off epoch percentages') -parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', +group.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT', help='learning rate noise limit percent (default: 0.67)') -parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', +group.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV', help='learning rate noise std-dev (default: 1.0)') -parser.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', +group.add_argument('--lr-cycle-mul', type=float, default=1.0, metavar='MULT', help='learning rate cycle len multiplier (default: 1.0)') -parser.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', +group.add_argument('--lr-cycle-decay', type=float, default=0.5, metavar='MULT', help='amount to decay each learning rate cycle (default: 0.5)') 
-parser.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', +group.add_argument('--lr-cycle-limit', type=int, default=1, metavar='N', help='learning rate cycle limit, cycles enabled if > 1') -parser.add_argument('--lr-k-decay', type=float, default=1.0, +group.add_argument('--lr-k-decay', type=float, default=1.0, help='learning rate k-decay for cosine/poly (default: 1.0)') -parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', +group.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR', help='warmup learning rate (default: 0.0001)') -parser.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', +group.add_argument('--min-lr', type=float, default=1e-6, metavar='LR', help='lower lr bound for cyclic schedulers that hit 0 (1e-5)') -parser.add_argument('--epochs', type=int, default=300, metavar='N', +group.add_argument('--epochs', type=int, default=300, metavar='N', help='number of epochs to train (default: 300)') -parser.add_argument('--epoch-repeats', type=float, default=0., metavar='N', +group.add_argument('--epoch-repeats', type=float, default=0., metavar='N', help='epoch repeat multiplier (number of times to repeat dataset epoch per train epoch).') -parser.add_argument('--start-epoch', default=None, type=int, metavar='N', +group.add_argument('--start-epoch', default=None, type=int, metavar='N', help='manual epoch number (useful on restarts)') -parser.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", +group.add_argument('--decay-milestones', default=[30, 60], type=int, nargs='+', metavar="MILESTONES", help='list of decay epoch indices for multistep lr. must be increasing') -parser.add_argument('--decay-epochs', type=float, default=100, metavar='N', +group.add_argument('--decay-epochs', type=float, default=100, metavar='N', help='epoch interval to decay LR') -parser.add_argument('--warmup-epochs', type=int, default=3, metavar='N', +group.add_argument('--warmup-epochs', type=int, default=3, metavar='N', help='epochs to warmup LR, if scheduler supports') -parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', +group.add_argument('--cooldown-epochs', type=int, default=10, metavar='N', help='epochs to cooldown LR at min_lr, after cyclic schedule ends') -parser.add_argument('--patience-epochs', type=int, default=10, metavar='N', +group.add_argument('--patience-epochs', type=int, default=10, metavar='N', help='patience epochs for Plateau LR scheduler (default: 10') -parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', +group.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE', help='LR decay rate (default: 0.1)') # Augmentation & regularization parameters -parser.add_argument('--no-aug', action='store_true', default=False, +group = parser.add_argument_group('Augmentation and regularization parameters') +group.add_argument('--no-aug', action='store_true', default=False, help='Disable all training augmentation, override other train aug args') -parser.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', +group.add_argument('--scale', type=float, nargs='+', default=[0.08, 1.0], metavar='PCT', help='Random resize scale (default: 0.08 1.0)') -parser.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', +group.add_argument('--ratio', type=float, nargs='+', default=[3./4., 4./3.], metavar='RATIO', help='Random resize aspect ratio (default: 0.75 1.33)') 
-parser.add_argument('--hflip', type=float, default=0.5, +group.add_argument('--hflip', type=float, default=0.5, help='Horizontal flip training aug probability') -parser.add_argument('--vflip', type=float, default=0., +group.add_argument('--vflip', type=float, default=0., help='Vertical flip training aug probability') -parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', +group.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT', help='Color jitter factor (default: 0.4)') -parser.add_argument('--aa', type=str, default=None, metavar='NAME', +group.add_argument('--aa', type=str, default=None, metavar='NAME', help='Use AutoAugment policy. "v0" or "original". (default: None)'), -parser.add_argument('--aug-repeats', type=float, default=0, +group.add_argument('--aug-repeats', type=float, default=0, help='Number of augmentation repetitions (distributed training only) (default: 0)') -parser.add_argument('--aug-splits', type=int, default=0, +group.add_argument('--aug-splits', type=int, default=0, help='Number of augmentation splits (default: 0, valid: 0 or >=2)') -parser.add_argument('--jsd-loss', action='store_true', default=False, +group.add_argument('--jsd-loss', action='store_true', default=False, help='Enable Jensen-Shannon Divergence + CE loss. Use with `--aug-splits`.') -parser.add_argument('--bce-loss', action='store_true', default=False, +group.add_argument('--bce-loss', action='store_true', default=False, help='Enable BCE loss w/ Mixup/CutMix use.') -parser.add_argument('--bce-target-thresh', type=float, default=None, +group.add_argument('--bce-target-thresh', type=float, default=None, help='Threshold for binarizing softened BCE targets (default: None, disabled)') -parser.add_argument('--reprob', type=float, default=0., metavar='PCT', +group.add_argument('--reprob', type=float, default=0., metavar='PCT', help='Random erase prob (default: 0.)') -parser.add_argument('--remode', type=str, default='pixel', +group.add_argument('--remode', type=str, default='pixel', help='Random erase mode (default: "pixel")') -parser.add_argument('--recount', type=int, default=1, +group.add_argument('--recount', type=int, default=1, help='Random erase count (default: 1)') -parser.add_argument('--resplit', action='store_true', default=False, +group.add_argument('--resplit', action='store_true', default=False, help='Do not random erase first (clean) augmentation split') -parser.add_argument('--mixup', type=float, default=0.0, +group.add_argument('--mixup', type=float, default=0.0, help='mixup alpha, mixup enabled if > 0. (default: 0.)') -parser.add_argument('--cutmix', type=float, default=0.0, +group.add_argument('--cutmix', type=float, default=0.0, help='cutmix alpha, cutmix enabled if > 0. 
(default: 0.)') -parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, +group.add_argument('--cutmix-minmax', type=float, nargs='+', default=None, help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)') -parser.add_argument('--mixup-prob', type=float, default=1.0, +group.add_argument('--mixup-prob', type=float, default=1.0, help='Probability of performing mixup or cutmix when either/both is enabled') -parser.add_argument('--mixup-switch-prob', type=float, default=0.5, +group.add_argument('--mixup-switch-prob', type=float, default=0.5, help='Probability of switching to cutmix when both mixup and cutmix enabled') -parser.add_argument('--mixup-mode', type=str, default='batch', +group.add_argument('--mixup-mode', type=str, default='batch', help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"') -parser.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', +group.add_argument('--mixup-off-epoch', default=0, type=int, metavar='N', help='Turn off mixup after this epoch, disabled if 0 (default: 0)') -parser.add_argument('--smoothing', type=float, default=0.1, +group.add_argument('--smoothing', type=float, default=0.1, help='Label smoothing (default: 0.1)') -parser.add_argument('--train-interpolation', type=str, default='random', +group.add_argument('--train-interpolation', type=str, default='random', help='Training interpolation (random, bilinear, bicubic default: "random")') -parser.add_argument('--drop', type=float, default=0.0, metavar='PCT', +group.add_argument('--drop', type=float, default=0.0, metavar='PCT', help='Dropout rate (default: 0.)') -parser.add_argument('--drop-connect', type=float, default=None, metavar='PCT', +group.add_argument('--drop-connect', type=float, default=None, metavar='PCT', help='Drop connect rate, DEPRECATED, use drop-path (default: None)') -parser.add_argument('--drop-path', type=float, default=None, metavar='PCT', +group.add_argument('--drop-path', type=float, default=None, metavar='PCT', help='Drop path rate (default: None)') -parser.add_argument('--drop-block', type=float, default=None, metavar='PCT', +group.add_argument('--drop-block', type=float, default=None, metavar='PCT', help='Drop block rate (default: None)') # Batch norm parameters (only works with gen_efficientnet based models currently) -parser.add_argument('--bn-momentum', type=float, default=None, +group = parser.add_argument_group('Batch norm parameters', 'Only works with gen_efficientnet based models currently.') +group.add_argument('--bn-momentum', type=float, default=None, help='BatchNorm momentum override (if not None)') -parser.add_argument('--bn-eps', type=float, default=None, +group.add_argument('--bn-eps', type=float, default=None, help='BatchNorm epsilon override (if not None)') -parser.add_argument('--sync-bn', action='store_true', +group.add_argument('--sync-bn', action='store_true', help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') -parser.add_argument('--dist-bn', type=str, default='reduce', +group.add_argument('--dist-bn', type=str, default='reduce', help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') -parser.add_argument('--split-bn', action='store_true', +group.add_argument('--split-bn', action='store_true', help='Enable separate BN layers per augmentation split.') # Model Exponential Moving Average -parser.add_argument('--model-ema', action='store_true', default=False, +group = parser.add_argument_group('Model exponential moving average 
parameters') +group.add_argument('--model-ema', action='store_true', default=False, help='Enable tracking moving average of model weights') -parser.add_argument('--model-ema-force-cpu', action='store_true', default=False, +group.add_argument('--model-ema-force-cpu', action='store_true', default=False, help='Force ema to be tracked on CPU, rank=0 node only. Disables EMA validation.') -parser.add_argument('--model-ema-decay', type=float, default=0.9998, +group.add_argument('--model-ema-decay', type=float, default=0.9998, help='decay factor for model weights moving average (default: 0.9998)') # Misc -parser.add_argument('--seed', type=int, default=42, metavar='S', +group = parser.add_argument_group('Miscellaneous parameters') +group.add_argument('--seed', type=int, default=42, metavar='S', help='random seed (default: 42)') -parser.add_argument('--worker-seeding', type=str, default='all', +group.add_argument('--worker-seeding', type=str, default='all', help='worker seed mode (default: all)') -parser.add_argument('--log-interval', type=int, default=50, metavar='N', +group.add_argument('--log-interval', type=int, default=50, metavar='N', help='how many batches to wait before logging training status') -parser.add_argument('--recovery-interval', type=int, default=0, metavar='N', +group.add_argument('--recovery-interval', type=int, default=0, metavar='N', help='how many batches to wait before writing recovery checkpoint') -parser.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', +group.add_argument('--checkpoint-hist', type=int, default=10, metavar='N', help='number of checkpoints to keep (default: 10)') -parser.add_argument('-j', '--workers', type=int, default=4, metavar='N', +group.add_argument('-j', '--workers', type=int, default=4, metavar='N', help='how many training processes to use (default: 4)') -parser.add_argument('--save-images', action='store_true', default=False, +group.add_argument('--save-images', action='store_true', default=False, help='save images of input bathes every log interval for debugging') -parser.add_argument('--amp', action='store_true', default=False, +group.add_argument('--amp', action='store_true', default=False, help='use NVIDIA Apex AMP or Native AMP for mixed precision training') -parser.add_argument('--apex-amp', action='store_true', default=False, +group.add_argument('--apex-amp', action='store_true', default=False, help='Use NVIDIA Apex AMP mixed precision') -parser.add_argument('--native-amp', action='store_true', default=False, +group.add_argument('--native-amp', action='store_true', default=False, help='Use Native Torch AMP mixed precision') -parser.add_argument('--no-ddp-bb', action='store_true', default=False, +group.add_argument('--no-ddp-bb', action='store_true', default=False, help='Force broadcast buffers for native DDP to off.') -parser.add_argument('--pin-mem', action='store_true', default=False, +group.add_argument('--pin-mem', action='store_true', default=False, help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.') -parser.add_argument('--no-prefetcher', action='store_true', default=False, +group.add_argument('--no-prefetcher', action='store_true', default=False, help='disable fast prefetcher') -parser.add_argument('--output', default='', type=str, metavar='PATH', +group.add_argument('--output', default='', type=str, metavar='PATH', help='path to output folder (default: none, current dir)') -parser.add_argument('--experiment', default='', type=str, metavar='NAME', +group.add_argument('--experiment', 
default='', type=str, metavar='NAME', help='name of train experiment, name of sub-folder for output') -parser.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', +group.add_argument('--eval-metric', default='top1', type=str, metavar='EVAL_METRIC', help='Best metric (default: "top1"') -parser.add_argument('--tta', type=int, default=0, metavar='N', +group.add_argument('--tta', type=int, default=0, metavar='N', help='Test/inference time augmentation (oversampling) factor. 0=None (default: 0)') -parser.add_argument("--local_rank", default=0, type=int) -parser.add_argument('--use-multi-epochs-loader', action='store_true', default=False, +group.add_argument("--local_rank", default=0, type=int) +group.add_argument('--use-multi-epochs-loader', action='store_true', default=False, help='use the multi-epochs-loader to save time at the beginning of every epoch') -parser.add_argument('--log-wandb', action='store_true', default=False, +group.add_argument('--log-wandb', action='store_true', default=False, help='log training and validation metrics to wandb') From ce5578bc3a3e2def84df79bc08e10be5f5fc7a14 Mon Sep 17 00:00:00 2001 From: Jakub Kaczmarzyk Date: Wed, 18 May 2022 11:04:10 -0400 Subject: [PATCH 31/32] replace star imports with imported names --- train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/train.py b/train.py index 6f31e295..3d19cd9f 100755 --- a/train.py +++ b/train.py @@ -31,8 +31,11 @@ from torch.nn.parallel import DistributedDataParallel as NativeDDP from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ convert_splitbn_model, model_parameters -from timm.utils import * -from timm.loss import * +from timm.utils import setup_default_logging, random_seed, set_jit_fuser, ModelEmaV2,\ + get_outdir, CheckpointSaver, distribute_bn, update_summary, accuracy, AverageMeter,\ + dispatch_clip_grad, reduce_tensor +from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\ + LabelSmoothingCrossEntropy from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler From 57f8361a01a68bfe0b96b84d9b07b74e6fb6ca92 Mon Sep 17 00:00:00 2001 From: Yonghye Kwon Date: Wed, 25 May 2022 00:36:28 +0900 Subject: [PATCH 32/32] fix a function parameter typo(cropt_pct -> crop_pct) --- validate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/validate.py b/validate.py index 2720f903..27b88299 100755 --- a/validate.py +++ b/validate.py @@ -271,7 +271,7 @@ def validate(args): top5=round(top5a, 4), top5_err=round(100 - top5a, 4), param_count=round(param_count / 1e6, 2), img_size=data_config['input_size'][-1], - cropt_pct=crop_pct, + crop_pct=crop_pct, interpolation=data_config['interpolation']) _logger.info(' * Acc@1 {:.3f} ({:.3f}) Acc@5 {:.3f} ({:.3f})'.format(
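# ---------------------------------------------------------------------------
# Illustrative aside, not part of the patch series above: the
# parser.add_argument_group() pattern that PATCH 30 applies across train.py,
# reduced to a minimal self-contained sketch. Group, option, and default
# names below are examples only, not timm's actual argument set.
# ---------------------------------------------------------------------------
import argparse

parser = argparse.ArgumentParser(description='Toy training script')

# Positional arguments stay on the parser itself, mirroring how train.py keeps
# `data_dir` outside the dataset group.
parser.add_argument('data_dir', metavar='DIR', help='path to dataset')

# add_argument_group() only changes how --help is rendered: each group gets its
# own heading. Parsing behaviour and the resulting Namespace are unchanged.
group = parser.add_argument_group('Dataset parameters')
group.add_argument('--dataset', '-d', metavar='NAME', default='',
                   help='dataset type')
group.add_argument('--train-split', metavar='NAME', default='train',
                   help='dataset train split')

group = parser.add_argument_group('Optimizer parameters')
group.add_argument('--opt', default='sgd', metavar='OPTIMIZER',
                   help='optimizer name')
group.add_argument('--lr', type=float, default=0.05, metavar='LR',
                   help='learning rate')

if __name__ == '__main__':
    # --help now lists options under 'Dataset parameters' and
    # 'Optimizer parameters' headings instead of one flat section.
    args = parser.parse_args(['/data/imagenet', '--opt', 'adamw', '--lr', '1e-3'])
    print(args.opt, args.lr)  # adamw 0.001
# ---------------------------------------------------------------------------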
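# ---------------------------------------------------------------------------
# Illustrative aside, not part of the patch series above: one classic failure
# mode that PATCH 31 rules out by replacing `from timm.utils import *` and
# `from timm.loss import *` with named imports. The fake modules below are
# stand-ins created for this sketch; timm itself is not required.
# ---------------------------------------------------------------------------
import sys
import types

fake_utils = types.ModuleType('fake_utils')
fake_loss = types.ModuleType('fake_loss')
fake_utils.AverageMeter = 'AverageMeter defined in fake_utils'
fake_loss.AverageMeter = 'AverageMeter defined in fake_loss'  # same public name
sys.modules['fake_utils'] = fake_utils
sys.modules['fake_loss'] = fake_loss

from fake_utils import *  # binds AverageMeter
from fake_loss import *   # silently rebinds AverageMeter; no warning, no error

# With star imports, neither the reader nor a linter can tell which module a
# name came from; explicit `from fake_utils import AverageMeter` keeps the
# origin visible and turns accidental shadowing into a visible diff rather
# than a runtime surprise.
print(AverageMeter)  # -> 'AverageMeter defined in fake_loss'
# ---------------------------------------------------------------------------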