diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1cc44acf..9f7aebdb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -17,8 +17,8 @@ jobs: matrix: os: [ubuntu-latest, macOS-latest] python: ['3.8'] - torch: ['1.8.0'] - torchvision: ['0.9.0'] + torch: ['1.8.1'] + torchvision: ['0.9.1'] runs-on: ${{ matrix.os }} steps: diff --git a/README.md b/README.md index 6a8d520e..06aee7ec 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor ## What's New +### May 25, 2021 +* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl +* Fix a number of torchscript issues with various vision transformer models +* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models +* More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare) +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + ### May 14, 2021 * Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. * 1k trained variants: `tf_efficientnetv2_s/m/l` @@ -166,30 +174,6 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor * Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript * PyPi release @ 0.3.2 (needed by EfficientDet) -### Oct 30, 2020 -* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. -* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. -* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. -* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. -* PyPi release @ 0.3.0 version! - -### Oct 26, 2020 -* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer -* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl - * ViT-B/16 - 84.2 - * ViT-B/32 - 81.7 - * ViT-L/16 - 85.2 - * ViT-L/32 - 81.5 - -### Oct 21, 2020 -* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. - -### Oct 13, 2020 -* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... -* Adafactor and AdaHessian (FP32 only, no AMP) optimizers -* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 -* Pip release, doc updates pending a few more changes... - ## Introduction @@ -207,6 +191,7 @@ A full version of the list below with source links can be found in the [document * Bottleneck Transformers - https://arxiv.org/abs/2101.11605 * CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239 * CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399 +* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697 * CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929 * DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877 * DenseNet - https://arxiv.org/abs/1608.06993 @@ -224,6 +209,7 @@ A full version of the list below with source links can be found in the [document * MobileNet-V2 - https://arxiv.org/abs/1801.04381 * Single-Path NAS - https://arxiv.org/abs/1904.02877 * GhostNet - https://arxiv.org/abs/1911.11907 +* gMLP - https://arxiv.org/abs/2105.08050 * GPU-Efficient Networks - https://arxiv.org/abs/2006.14090 * Halo Nets - https://arxiv.org/abs/2103.12731 * HardCoRe-NAS - https://arxiv.org/abs/2102.11646 @@ -231,6 +217,7 @@ A full version of the list below with source links can be found in the [document * Inception-V3 - https://arxiv.org/abs/1512.00567 * Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261 * Lambda Networks - https://arxiv.org/abs/2102.08602 +* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136 * MLP-Mixer - https://arxiv.org/abs/2105.01601 * MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244 * NASNet-A - https://arxiv.org/abs/1707.07012 @@ -240,6 +227,7 @@ A full version of the list below with source links can be found in the [document * Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302 * RegNet - https://arxiv.org/abs/2003.13678 * RepVGG - https://arxiv.org/abs/2101.03697 +* ResMLP - https://arxiv.org/abs/2105.03404 * ResNet/ResNeXt * ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385 * ResNeXt - https://arxiv.org/abs/1611.05431 @@ -257,6 +245,7 @@ A full version of the list below with source links can be found in the [document * Swin Transformer - https://arxiv.org/abs/2103.14030 * Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112 * TResNet - https://arxiv.org/abs/2003.13630 +* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf * Vision Transformer - https://arxiv.org/abs/2010.11929 * VovNet V2 and V1 - https://arxiv.org/abs/1911.06667 * Xception - https://arxiv.org/abs/1610.02357 diff --git a/docs/archived_changes.md b/docs/archived_changes.md index 857a914d..56ee706f 100644 --- a/docs/archived_changes.md +++ b/docs/archived_changes.md @@ -1,5 +1,29 @@ # Archived Changes +### Oct 30, 2020 +* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue. +* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16. +* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated. +* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage. +* PyPi release @ 0.3.0 version! + +### Oct 26, 2020 +* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer +* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl + * ViT-B/16 - 84.2 + * ViT-B/32 - 81.7 + * ViT-L/16 - 85.2 + * ViT-L/32 - 81.5 + +### Oct 21, 2020 +* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs. + +### Oct 13, 2020 +* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train... +* Adafactor and AdaHessian (FP32 only, no AMP) optimizers +* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1 +* Pip release, doc updates pending a few more changes... + ### Sept 18, 2020 * New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D * Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D) diff --git a/docs/changes.md b/docs/changes.md index b0ac125c..9719dd65 100644 --- a/docs/changes.md +++ b/docs/changes.md @@ -1,5 +1,33 @@ # Recent Changes +### May 25, 2021 +* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models +* Cleanup input_size/img_size override handling and testing for all vision transformer models +* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params. + +### May 14, 2021 +* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl. + * 1k trained variants: `tf_efficientnetv2_s/m/l` + * 21k trained variants: `tf_efficientnetv2_s/m/l_in21k` + * 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k` + * v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3` + * Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s` + * Some blank `efficientnetv2_*` models in-place for future native PyTorch training + +### May 5, 2021 +* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen) +* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit) +* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora) +* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin) +* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23) +* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai) +* Update ByoaNet attention modles + * Improve SA module inits + * Hack together experimental stand-alone Swin based attn module and `swinnet` + * Consistent '26t' model defs for experiments. +* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1. +* WandB logging support + ### April 13, 2021 * Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer diff --git a/tests/test_models.py b/tests/test_models.py index fa148133..44cb3ba2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -16,7 +16,8 @@ if hasattr(torch._C, '_jit_set_profiling_executor'): # transformer models don't support many of the spatial / feature based model functionalities NON_STD_FILTERS = [ - 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', 'convit_*'] + 'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*', + 'convit_*', 'levit*', 'visformer*'] NUM_NON_STD = len(NON_STD_FILTERS) # exclude models that cause specific test failures @@ -25,29 +26,56 @@ if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): EXCLUDE_FILTERS = [ '*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', - '*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS + '*resnetrs350*', '*resnetrs420*'] else: - EXCLUDE_FILTERS = NON_STD_FILTERS + EXCLUDE_FILTERS = [] -MAX_FWD_SIZE = 384 -MAX_BWD_SIZE = 128 -MAX_FWD_FEAT_SIZE = 448 +TARGET_FWD_SIZE = MAX_FWD_SIZE = 384 +TARGET_BWD_SIZE = 128 +MAX_BWD_SIZE = 320 +MAX_FWD_OUT_SIZE = 448 +TARGET_JIT_SIZE = 128 +MAX_JIT_SIZE = 320 +TARGET_FFEAT_SIZE = 96 +MAX_FFEAT_SIZE = 256 + + +def _get_input_size(model=None, model_name='', target=None): + if model is None: + assert model_name, "One of model or model_name must be provided" + input_size = get_model_default_value(model_name, 'input_size') + fixed_input_size = get_model_default_value(model_name, 'fixed_input_size') + min_input_size = get_model_default_value(model_name, 'min_input_size') + else: + default_cfg = model.default_cfg + input_size = default_cfg['input_size'] + fixed_input_size = default_cfg.get('fixed_input_size', None) + min_input_size = default_cfg.get('min_input_size', None) + assert input_size is not None + + if fixed_input_size: + return input_size + + if min_input_size: + if target and max(input_size) > target: + input_size = min_input_size + else: + if target and max(input_size) > target: + input_size = tuple([min(x, target) for x in input_size]) + return input_size @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD])) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" model = create_model(model_name, pretrained=False) model.eval() - input_size = model.default_cfg['input_size'] - if any([x > MAX_FWD_SIZE for x in input_size]): - if is_model_default_key(model_name, 'fixed_input_size'): - pytest.skip("Fixed input size model > limit.") - # cap forward test at max res 384 * 384 to keep resource down - input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size]) + input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE) + if max(input_size) > MAX_FWD_SIZE: + pytest.skip("Fixed input size model > limit.") inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) @@ -56,26 +84,22 @@ def test_model_forward(model_name, batch_size): @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [2]) def test_model_backward(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE) + if max(input_size) > MAX_BWD_SIZE: + pytest.skip("Fixed input size model > limit.") + model = create_model(model_name, pretrained=False, num_classes=42) num_params = sum([x.numel() for x in model.parameters()]) - model.eval() - - input_size = model.default_cfg['input_size'] - if not is_model_default_key(model_name, 'fixed_input_size'): - min_input_size = get_model_default_value(model_name, 'min_input_size') - if min_input_size is not None: - input_size = min_input_size - else: - if any([x > MAX_BWD_SIZE for x in input_size]): - # cap backward test at 128 * 128 to keep resource usage down - input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size]) + model.train() inputs = torch.randn((batch_size, *input_size)) outputs = model(inputs) + if isinstance(outputs, tuple): + outputs = torch.cat(outputs) outputs.mean().backward() for n, x in model.named_parameters(): assert x.grad is not None, f'No gradient for {n}' @@ -100,10 +124,10 @@ def test_model_default_cfgs(model_name, batch_size): pool_size = cfg['pool_size'] input_size = model.default_cfg['input_size'] - if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \ + if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \ not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]): # output sizes only checked if default res <= 448 * 448 to keep resource down - input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size]) + input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size]) input_tensor = torch.randn((batch_size, *input_size)) # test forward_features (always unpooled) @@ -154,26 +178,25 @@ if 'GITHUB_ACTIONS' not in os.environ: EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable - 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point + 'vit_large_*', 'vit_huge_*', ] @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS)) +@pytest.mark.parametrize( + 'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True)) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward_torchscript(model_name, batch_size): """Run a single forward pass with each model""" + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + if max(input_size) > MAX_JIT_SIZE: + pytest.skip("Fixed input size model > limit.") + with set_scriptable(True): model = create_model(model_name, pretrained=False) model.eval() - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already... - model = torch.jit.script(model) outputs = model(torch.randn((batch_size, *input_size))) @@ -183,7 +206,7 @@ def test_model_forward_torchscript(model_name, batch_size): EXCLUDE_FEAT_FILTERS = [ '*pruned*', # hopefully fix at some point -] +] + NON_STD_FILTERS if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system(): # GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d'] @@ -199,12 +222,9 @@ def test_model_forward_features(model_name, batch_size): expected_channels = model.feature_info.channels() assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6 - if has_model_default_key(model_name, 'fixed_input_size'): - input_size = get_model_default_value(model_name, 'input_size') - elif has_model_default_key(model_name, 'min_input_size'): - input_size = get_model_default_value(model_name, 'min_input_size') - else: - input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already... + input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE) + if max(input_size) > MAX_FFEAT_SIZE: + pytest.skip("Fixed input size model > limit.") outputs = model(torch.randn((batch_size, *input_size))) assert len(expected_channels) == len(outputs) diff --git a/timm/data/parsers/parser_tfds.py b/timm/data/parsers/parser_tfds.py index 0b12a4db..2ff90b09 100644 --- a/timm/data/parsers/parser_tfds.py +++ b/timm/data/parsers/parser_tfds.py @@ -25,8 +25,8 @@ from .parser import Parser MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities -SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue -PREFETCH_SIZE = 4096 # samples to prefetch +SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue +PREFETCH_SIZE = 2048 # samples to prefetch def even_split_indices(split, n, num_samples): @@ -144,14 +144,16 @@ class ParserTfds(Parser): ds = self.builder.as_dataset( split=self.subsplit or self.split, shuffle_files=self.shuffle, read_config=read_config) # avoid overloading threading w/ combo fo TF ds threads + PyTorch workers - ds.options().experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) - ds.options().experimental_threading.max_intra_op_parallelism = 1 + options = tf.data.Options() + options.experimental_threading.private_threadpool_size = max(1, MAX_TP_SIZE // num_workers) + options.experimental_threading.max_intra_op_parallelism = 1 + ds = ds.with_options(options) if self.is_training or self.repeats > 1: # to prevent excessive drop_last batch behaviour w/ IterableDatasets # see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading ds = ds.repeat() # allow wrap around and break iteration manually if self.shuffle: - ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0) + ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0) ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE)) self.ds = tfds.as_numpy(ds) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 0488094c..788b7518 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -16,6 +16,8 @@ from .hrnet import * from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * +from .levit import * +#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * @@ -35,6 +37,7 @@ from .swin_transformer import * from .tnt import * from .tresnet import * from .vgg import * +from .visformer import * from .vision_transformer import * from .vision_transformer_hybrid import * from .vovnet import * diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index ca49089b..c179a01c 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -47,17 +47,24 @@ default_cfgs = { # GPU-Efficient (ResNet) weights 'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), + 'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)), 'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)), + 'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), + + 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), } @@ -126,6 +133,23 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict() ), + eca_botnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='bottleneck', + self_attn_fixed_size=True, + self_attn_kwargs=dict() + ), halonet_h1=ByoaCfg( blocks=( @@ -184,6 +208,22 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) ), + eca_halonext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='halo', + self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res + ), lambda_resnet26t=ByoaCfg( blocks=( @@ -213,6 +253,22 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), + eca_lambda_resnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='lambda', + self_attn_kwargs=dict() + ), swinnet26t=ByoaCfg( blocks=( @@ -245,6 +301,56 @@ model_cfgs = dict( self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), + eca_swinnext26ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + attn_layer='eca', + self_attn_layer='swin', + self_attn_fixed_size=True, + self_attn_kwargs=dict(win_size=8) + ), + + + rednet26t=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', # FIXME RedNet uses involution in middle of stem + stem_pool='maxpool', + num_features=0, + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), + rednet50ts=ByoaCfg( + blocks=( + ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool='maxpool', + num_features=0, + act_layer='silu', + self_attn_layer='involution', + self_attn_fixed_size=False, + self_attn_kwargs=dict() + ), ) @@ -419,6 +525,14 @@ def botnet50ts_256(pretrained=False, **kwargs): return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_botnext26ts_256(pretrained=False, **kwargs): + """ Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage. + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs) + + @register_model def halonet_h1(pretrained=False, **kwargs): """ HaloNet-H1. Halo attention in all stages as per the paper. @@ -449,6 +563,13 @@ def halonet50ts(pretrained=False, **kwargs): return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs) +@register_model +def eca_halonext26ts(pretrained=False, **kwargs): + """ HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage + """ + return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs) + + @register_model def lambda_resnet26t(pretrained=False, **kwargs): """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. @@ -463,6 +584,13 @@ def lambda_resnet50t(pretrained=False, **kwargs): return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs) +@register_model +def eca_lambda_resnext26ts(pretrained=False, **kwargs): + """ Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5. + """ + return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs) + + @register_model def swinnet26t_256(pretrained=False, **kwargs): """ @@ -477,3 +605,25 @@ def swinnet50ts_256(pretrained=False, **kwargs): """ kwargs.setdefault('img_size', 256) return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs) + + +@register_model +def eca_swinnext26ts_256(pretrained=False, **kwargs): + """ + """ + kwargs.setdefault('img_size', 256) + return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs) + + +@register_model +def rednet26t(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs) + + +@register_model +def rednet50ts(pretrained=False, **kwargs): + """ + """ + return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs) diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 75610f67..8f4a2020 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -98,7 +98,7 @@ class BlocksCfg: s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. # bottleneck-ratio of blocks in stage - no_attn: bool = True # disable channel attn (ie SE) when layer is set for model + no_attn: bool = False # disable channel attn (ie SE) when layer is set for model @dataclass diff --git a/timm/models/cait.py b/timm/models/cait.py index c5f7742f..aa2e5f07 100644 --- a/timm/models/cait.py +++ b/timm/models/cait.py @@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None): return checkpoint_no_module -def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_cait(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Cait, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/coat.py b/timm/models/coat.py index cb265522..9eb384d8 100644 --- a/timm/models/coat.py +++ b/timm/models/coat.py @@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT Modified from timm/models/vision_transformer.py """ -from typing import Tuple, Dict, Any, Optional +from copy import deepcopy +from functools import partial +from typing import Tuple, List import torch import torch.nn as nn import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained -from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ -from timm.models.registry import register_model +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_ +from .registry import register_model -from functools import partial -from torch import nn __all__ = [ "coat_tiny", @@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'first_conv': 'patch_embed1.proj', 'classifier': 'head', **kwargs @@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs): default_cfgs = { - 'coat_tiny': _cfg_coat(), - 'coat_mini': _cfg_coat(), + 'coat_tiny': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth' + ), + 'coat_mini': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth' + ), 'coat_lite_tiny': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth' ), 'coat_lite_mini': _cfg_coat( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth' ), - 'coat_lite_small': _cfg_coat(), + 'coat_lite_small': _cfg_coat( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth' + ), } @@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module): class FactorAtt_ConvRelPosEnc(nn.Module): """ Factorized attention with convolutional relative position encoding class. """ - def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None): + def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None): super().__init__() self.num_heads = num_heads head_dim = dim // num_heads - self.scale = qk_scale or head_dim ** -0.5 + self.scale = head_dim ** -0.5 self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used. @@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module): class SerialBlock(nn.Module): """ Serial block class. Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """ - def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpe=None, shared_crpe=None): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None): super().__init__() # Conv-Attention. @@ -200,8 +205,7 @@ class SerialBlock(nn.Module): self.norm1 = norm_layer(dim) self.factoratt_crpe = FactorAtt_ConvRelPosEnc( - dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, - shared_crpe=shared_crpe) + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() # MLP. @@ -226,27 +230,24 @@ class SerialBlock(nn.Module): class ParallelBlock(nn.Module): """ Parallel block class. """ - def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., - drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, - shared_cpes=None, shared_crpes=None): + def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None): super().__init__() # Conv-Attention. - self.cpes = shared_cpes - self.norm12 = norm_layer(dims[1]) self.norm13 = norm_layer(dims[2]) self.norm14 = norm_layer(dims[3]) self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc( - dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[1] ) self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc( - dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[2] ) self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc( - dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, + dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpes[3] ) self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() @@ -262,15 +263,15 @@ class ParallelBlock(nn.Module): self.mlp2 = self.mlp3 = self.mlp4 = Mlp( in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) - def upsample(self, x, factor, size): + def upsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map up-sampling. """ return self.interpolate(x, scale_factor=factor, size=size) - def downsample(self, x, factor, size): + def downsample(self, x, factor: float, size: Tuple[int, int]): """ Feature map down-sampling. """ return self.interpolate(x, scale_factor=1.0/factor, size=size) - def interpolate(self, x, scale_factor, size): + def interpolate(self, x, scale_factor: float, size: Tuple[int, int]): """ Feature map interpolation. """ B, N, C = x.shape H, W = size @@ -280,33 +281,28 @@ class ParallelBlock(nn.Module): img_tokens = x[:, 1:, :] img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W) - img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear') + img_tokens = F.interpolate( + img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False) img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2) out = torch.cat((cls_token, img_tokens), dim=1) return out - def forward(self, x1, x2, x3, x4, sizes): - _, (H2, W2), (H3, W3), (H4, W4) = sizes - - # Conv-Attention. - x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored. - x3 = self.cpes[2](x3, size=(H3, W3)) - x4 = self.cpes[3](x4, size=(H4, W4)) - + def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]): + _, S2, S3, S4 = sizes cur2 = self.norm12(x2) cur3 = self.norm13(x3) cur4 = self.norm14(x4) - cur2 = self.factoratt_crpe2(cur2, size=(H2, W2)) - cur3 = self.factoratt_crpe3(cur3, size=(H3, W3)) - cur4 = self.factoratt_crpe4(cur4, size=(H4, W4)) - upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3)) - upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4)) - upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4)) - downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2)) - downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3)) - downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2)) + cur2 = self.factoratt_crpe2(cur2, size=S2) + cur3 = self.factoratt_crpe3(cur3, size=S3) + cur4 = self.factoratt_crpe4(cur4, size=S4) + upsample3_2 = self.upsample(cur3, factor=2., size=S3) + upsample4_3 = self.upsample(cur4, factor=2., size=S4) + upsample4_2 = self.upsample(cur4, factor=4., size=S4) + downsample2_3 = self.downsample(cur2, factor=2., size=S2) + downsample3_4 = self.downsample(cur3, factor=2., size=S3) + downsample2_4 = self.downsample(cur2, factor=4., size=S2) cur2 = cur2 + upsample3_2 + upsample4_2 cur3 = cur3 + upsample4_3 + downsample2_3 cur4 = cur4 + downsample3_4 + downsample2_4 @@ -330,11 +326,11 @@ class ParallelBlock(nn.Module): class CoaT(nn.Module): """ CoaT class. """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0], - serial_depths=[0, 0, 0, 0], parallel_depth=0, - num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), - return_interm_layers=False, out_features = None, crpe_window=None, **kwargs): + def __init__( + self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0), + serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + return_interm_layers=False, out_features=None, crpe_window=None, **kwargs): super().__init__() crpe_window = crpe_window or {3: 2, 5: 3, 7: 3} self.return_interm_layers = return_interm_layers @@ -342,17 +338,18 @@ class CoaT(nn.Module): self.num_classes = num_classes # Patch embeddings. + img_size = to_2tuple(img_size) self.patch_embed1 = PatchEmbed( img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dims[0], norm_layer=nn.LayerNorm) self.patch_embed2 = PatchEmbed( - img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0], + img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0], embed_dim=embed_dims[1], norm_layer=nn.LayerNorm) self.patch_embed3 = PatchEmbed( - img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1], + img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1], embed_dim=embed_dims[2], norm_layer=nn.LayerNorm) self.patch_embed4 = PatchEmbed( - img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2], + img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2], embed_dim=embed_dims[3], norm_layer=nn.LayerNorm) # Class tokens. @@ -380,7 +377,7 @@ class CoaT(nn.Module): # Serial blocks 1. self.serial_blocks1 = nn.ModuleList([ SerialBlock( - dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe1, shared_crpe=self.crpe1 ) @@ -390,7 +387,7 @@ class CoaT(nn.Module): # Serial blocks 2. self.serial_blocks2 = nn.ModuleList([ SerialBlock( - dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe2, shared_crpe=self.crpe2 ) @@ -400,7 +397,7 @@ class CoaT(nn.Module): # Serial blocks 3. self.serial_blocks3 = nn.ModuleList([ SerialBlock( - dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe3, shared_crpe=self.crpe3 ) @@ -410,7 +407,7 @@ class CoaT(nn.Module): # Serial blocks 4. self.serial_blocks4 = nn.ModuleList([ SerialBlock( - dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale, + dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, shared_cpe=self.cpe4, shared_crpe=self.crpe4 ) @@ -422,10 +419,9 @@ class CoaT(nn.Module): if self.parallel_depth > 0: self.parallel_blocks = nn.ModuleList([ ParallelBlock( - dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale, - drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, - shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4], - shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4] + dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer, + shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4) ) for _ in range(parallel_depth)] ) @@ -434,9 +430,11 @@ class CoaT(nn.Module): # Classification head(s). if not self.return_interm_layers: - self.norm1 = norm_layer(embed_dims[0]) - self.norm2 = norm_layer(embed_dims[1]) - self.norm3 = norm_layer(embed_dims[2]) + if self.parallel_blocks is not None: + self.norm2 = norm_layer(embed_dims[1]) + self.norm3 = norm_layer(embed_dims[2]) + else: + self.norm2 = self.norm3 = None self.norm4 = norm_layer(embed_dims[3]) if self.parallel_depth > 0: @@ -546,6 +544,7 @@ class CoaT(nn.Module): # Parallel blocks. for blk in self.parallel_blocks: + x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4)) x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)]) if not torch.jit.is_scripting() and self.return_interm_layers: @@ -590,52 +589,70 @@ class CoaT(nn.Module): return x +def checkpoint_filter_fn(state_dict, model): + out_dict = {} + for k, v in state_dict.items(): + # original model had unused norm layers, removing them requires filtering pretrained checkpoints + if k.startswith('norm1') or \ + (model.norm2 is None and k.startswith('norm2')) or \ + (model.norm3 is None and k.startswith('norm3')): + continue + out_dict[k] = v + return out_dict + + +def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + CoaT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def coat_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_tiny'] + model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6, num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_mini'] + model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_tiny(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_tiny'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_mini(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - # FIXME use builder - model.default_cfg = default_cfgs['coat_lite_mini'] - if pretrained: - load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg) return model @register_model def coat_lite_small(pretrained=False, **kwargs): - model = CoaT( + model_cfg = dict( patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0, num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs) - model.default_cfg = default_cfgs['coat_lite_small'] + model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg) return model \ No newline at end of file diff --git a/timm/models/convit.py b/timm/models/convit.py index f6ae3ec1..b15b46d8 100644 --- a/timm/models/convit.py +++ b/timm/models/convit.py @@ -39,7 +39,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True, 'first_conv': 'patch_embed.proj', 'classifier': 'head', **kwargs } @@ -317,6 +317,9 @@ class ConViT(nn.Module): def _create_convit(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + return build_model_with_cfg( ConViT, variant, pretrained, default_cfg=default_cfgs[variant], diff --git a/timm/models/helpers.py b/timm/models/helpers.py index e9ac7f00..dfb6b860 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -44,7 +44,7 @@ def load_state_dict(checkpoint_path, use_ema=False): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=False, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=False, strict=False): state_dict = load_state_dict(checkpoint_path, use_ema) model.load_state_dict(state_dict, strict=strict) @@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter): # Overlay default cfg values from `external_default_cfg` if it exists in kwargs overlay_external_default_cfg(default_cfg, kwargs) # Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs) - set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg) + default_kwarg_names = ('num_classes', 'global_pool', 'in_chans') + if default_cfg.get('fixed_input_size', False): + # if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size + default_kwarg_names += ('img_size',) + set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg) # Filter keyword args for task specific model variants (some 'features only' models, etc.) filter_kwargs(kwargs, names=kwargs_filter) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 4aae99e3..cd192281 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -18,6 +18,7 @@ from .eca import EcaModule, CecaModule from .evo_norm import EvoNormBatch2d, EvoNormSample2d from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn +from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py index ba208f17..448ddb34 100644 --- a/timm/models/layers/create_self_attn.py +++ b/timm/models/layers/create_self_attn.py @@ -1,5 +1,6 @@ from .bottleneck_attn import BottleneckAttn from .halo_attn import HaloAttn +from .involution import Involution from .lambda_layer import LambdaLayer from .swin_attn import WindowAttention @@ -13,6 +14,8 @@ def get_self_attn(attn_type): return LambdaLayer elif attn_type == 'swin': return WindowAttention + elif attn_type == 'involution': + return Involution else: assert False, f"Unknown attn type ({attn_type})" diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py new file mode 100644 index 00000000..0dba9fae --- /dev/null +++ b/timm/models/layers/involution.py @@ -0,0 +1,50 @@ +""" PyTorch Involution Layer + +Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py +Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255 +""" +import torch.nn as nn +from .conv_bn_act import ConvBnAct +from .create_conv2d import create_conv2d + + +class Involution(nn.Module): + + def __init__( + self, + channels, + kernel_size=3, + stride=1, + group_size=16, + reduction_ratio=4, + norm_layer=nn.BatchNorm2d, + act_layer=nn.ReLU, + ): + super(Involution, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.channels = channels + self.group_size = group_size + self.groups = self.channels // self.group_size + self.conv1 = ConvBnAct( + in_channels=channels, + out_channels=channels // reduction_ratio, + kernel_size=1, + norm_layer=norm_layer, + act_layer=act_layer) + self.conv2 = self.conv = create_conv2d( + in_channels=channels // reduction_ratio, + out_channels=kernel_size**2 * self.groups, + kernel_size=1, + stride=1) + self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity() + self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride) + + def forward(self, x): + weight = self.conv2(self.conv1(self.avgpool(x))) + B, C, H, W = weight.shape + KK = int(self.kernel_size ** 2) + weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2) + out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W) + out = (weight * out).sum(dim=3).view(B, self.channels, H, W) + return out diff --git a/timm/models/layers/patch_embed.py b/timm/models/layers/patch_embed.py index b06f9982..42997fb8 100644 --- a/timm/models/layers/patch_embed.py +++ b/timm/models/layers/patch_embed.py @@ -15,7 +15,7 @@ from .helpers import to_2tuple class PatchEmbed(nn.Module): """ 2D Image to Patch Embedding """ - def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None): + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True): super().__init__() img_size = to_2tuple(img_size) patch_size = to_2tuple(patch_size) @@ -23,6 +23,7 @@ class PatchEmbed(nn.Module): self.patch_size = patch_size self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() @@ -31,6 +32,8 @@ class PatchEmbed(nn.Module): B, C, H, W = x.shape assert H == self.img_size[0] and W == self.img_size[1], \ f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." - x = self.proj(x).flatten(2).transpose(1, 2) + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC x = self.norm(x) return x diff --git a/timm/models/levit.py b/timm/models/levit.py new file mode 100644 index 00000000..5019ee9a --- /dev/null +++ b/timm/models/levit.py @@ -0,0 +1,568 @@ +""" LeViT + +Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference` + - https://arxiv.org/abs/2104.01136 + +@article{graham2021levit, + title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference}, + author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze}, + journal={arXiv preprint arXiv:22104.01136}, + year={2021} +} + +Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow. + +This version combines both conv/linear models and fixes torchscript compatibility. + +Modifications by/coyright Copyright 2021 Ross Wightman +""" + +# Copyright (c) 2015-present, Facebook, Inc. +# All rights reserved. + +# Modified from +# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py +# Copyright 2020 Ross Wightman, Apache-2.0 License +import itertools +from copy import deepcopy +from functools import partial +from typing import Dict + +import torch +import torch.nn as nn + +from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_ntuple +from .vision_transformer import trunc_normal_ +from .registry import register_model + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'), + **kwargs + } + + +default_cfgs = dict( + levit_128s=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth' + ), + levit_128=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth' + ), + levit_192=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth' + ), + levit_256=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth' + ), + levit_384=_cfg( + url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth' + ), +) + +model_cfgs = dict( + levit_128s=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)), + levit_128=dict( + embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)), + levit_192=dict( + embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)), + levit_256=dict( + embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)), + levit_384=dict( + embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)), +) + +__all__ = ['Levit'] + + +@register_model +def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs): + return create_levit( + 'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +@register_model +def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs): + return create_levit( + 'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs) + + +class ConvNorm(nn.Sequential): + def __init__( + self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000): + super().__init__() + self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = nn.BatchNorm2d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Conv2d( + w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, + padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class LinearNorm(nn.Sequential): + def __init__(self, a, b, bn_weight_init=1, resolution=-100000): + super().__init__() + self.add_module('c', nn.Linear(a, b, bias=False)) + bn = nn.BatchNorm1d(b) + nn.init.constant_(bn.weight, bn_weight_init) + nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + l, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[:, None] + b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + def forward(self, x): + x = self.c(x) + return self.bn(x.flatten(0, 1)).reshape_as(x) + + +class NormLinear(nn.Sequential): + def __init__(self, a, b, bias=True, std=0.02): + super().__init__() + self.add_module('bn', nn.BatchNorm1d(a)) + l = nn.Linear(a, b, bias=bias) + trunc_normal_(l.weight, std=std) + if bias: + nn.init.constant_(l.bias, 0) + self.add_module('l', l) + + @torch.no_grad() + def fuse(self): + bn, l = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps) ** 0.5 + b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5 + w = l.weight * w[None, :] + if l.bias is None: + b = b @ self.l.weight.T + else: + b = (l.weight @ b[:, None]).view(-1) + self.l.bias + m = nn.Linear(w.size(1), w.size(0)) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +def stem_b16(in_chs, out_chs, activation, resolution=224): + return nn.Sequential( + ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution), + activation(), + ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2), + activation(), + ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4), + activation(), + ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8)) + + +class Residual(nn.Module): + def __init__(self, m, drop): + super().__init__() + self.m = m + self.drop = drop + + def forward(self, x): + if self.training and self.drop > 0: + return x + self.m(x) * torch.rand( + x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach() + else: + return x + self.m(x) + + +class Subsample(nn.Module): + def __init__(self, stride, resolution): + super().__init__() + self.stride = stride + self.resolution = resolution + + def forward(self, x): + B, N, C = x.shape + x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride] + return x.reshape(B, -1, C) + + +class Attention(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False): + super().__init__() + + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm + h = self.dh + nh_kd * 2 + self.qkv = ln_layer(dim, h, resolution=resolution) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution)) + + points = list(itertools.product(range(resolution), range(resolution))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N)) + self.ab = {} + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): # x (B,C,H,W) + if self.use_conv: + B, C, H, W = x.shape + q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W) + else: + B, N, C = x.shape + qkv = self.qkv(x) + q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class AttentionSubsample(nn.Module): + ab: Dict[str, torch.Tensor] + + def __init__( + self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2, + act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False): + super().__init__() + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = self.d * self.num_heads + self.attn_ratio = attn_ratio + self.resolution_ = resolution_ + self.resolution_2 = resolution_ ** 2 + self.use_conv = use_conv + if self.use_conv: + ln_layer = ConvNorm + sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0) + else: + ln_layer = LinearNorm + sub_layer = partial(Subsample, resolution=resolution) + + h = self.dh + nh_kd + self.kv = ln_layer(in_dim, h, resolution=resolution) + self.q = nn.Sequential( + sub_layer(stride=stride), + ln_layer(in_dim, nh_kd, resolution=resolution_)) + self.proj = nn.Sequential( + act_layer(), + ln_layer(self.dh, out_dim, resolution=resolution_)) + + self.stride = stride + self.resolution = resolution + points = list(itertools.product(range(resolution), range(resolution))) + points_ = list(itertools.product(range(resolution_), range(resolution_))) + N = len(points) + N_ = len(points_) + attention_offsets = {} + idxs = [] + for p1 in points_: + for p2 in points: + size = 1 + offset = ( + abs(p1[0] * stride - p2[0] + (size - 1) / 2), + abs(p1[1] * stride - p2[1] + (size - 1) / 2)) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N)) + self.ab = {} # per-device attention_biases cache + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and self.ab: + self.ab = {} # clear ab cache + + def get_attention_biases(self, device: torch.device) -> torch.Tensor: + if self.training: + return self.attention_biases[:, self.attention_bias_idxs] + else: + device_key = str(device) + if device_key not in self.ab: + self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs] + return self.ab[device_key] + + def forward(self, x): + if self.use_conv: + B, C, H, W = x.shape + k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2) + q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_) + else: + B, N, C = x.shape + k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3) + k = k.permute(0, 2, 1, 3) # BHNC + v = v.permute(0, 2, 1, 3) # BHNC + q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3) + + attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device) + attn = attn.softmax(dim=-1) + + x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh) + x = self.proj(x) + return x + + +class Levit(nn.Module): + """ Vision Transformer with support for patch or hybrid CNN input stage + """ + + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + num_classes=1000, + embed_dim=(192,), + key_dim=64, + depth=(12,), + num_heads=(3,), + attn_ratio=2, + mlp_ratio=2, + hybrid_backbone=None, + down_ops=None, + act_layer=nn.Hardswish, + attn_act_layer=nn.Hardswish, + distillation=True, + use_conv=False, + drop_path=0): + super().__init__() + if isinstance(img_size, tuple): + # FIXME origin impl passes single img/res dim through whole hierarchy, + # not sure this model will be used enough to spend time fixing it. + assert img_size[0] == img_size[1] + img_size = img_size[0] + self.num_classes = num_classes + self.num_features = embed_dim[-1] + self.embed_dim = embed_dim + N = len(embed_dim) + assert len(depth) == len(num_heads) == N + key_dim = to_ntuple(N)(key_dim) + attn_ratio = to_ntuple(N)(attn_ratio) + mlp_ratio = to_ntuple(N)(mlp_ratio) + down_ops = down_ops or ( + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + ('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2), + ('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2), + ('',) + ) + self.distillation = distillation + self.use_conv = use_conv + ln_layer = ConvNorm if self.use_conv else LinearNorm + + self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer) + + self.blocks = [] + resolution = img_size // patch_size + for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate( + zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)): + for _ in range(dpth): + self.blocks.append( + Residual( + Attention( + ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer, + resolution=resolution, use_conv=use_conv), + drop_path)) + if mr > 0: + h = int(ed * mr) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(ed, h, resolution=resolution), + act_layer(), + ln_layer(h, ed, bn_weight_init=0, resolution=resolution), + ), drop_path)) + if do[0] == 'Subsample': + # ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride) + resolution_ = (resolution - 1) // do[5] + 1 + self.blocks.append( + AttentionSubsample( + *embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2], + attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5], + resolution=resolution, resolution_=resolution_, use_conv=use_conv)) + resolution = resolution_ + if do[4] > 0: # mlp_ratio + h = int(embed_dim[i + 1] * do[4]) + self.blocks.append( + Residual(nn.Sequential( + ln_layer(embed_dim[i + 1], h, resolution=resolution), + act_layer(), + ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution), + ), drop_path)) + self.blocks = nn.Sequential(*self.blocks) + + # Classifier head + self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + if distillation: + self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity() + else: + self.head_dist = None + + @torch.jit.ignore + def no_weight_decay(self): + return {x for x in self.state_dict().keys() if 'attention_biases' in x} + + def forward(self, x): + x = self.patch_embed(x) + if not self.use_conv: + x = x.flatten(2).transpose(1, 2) + x = self.blocks(x) + x = x.mean((-2, -1)) if self.use_conv else x.mean(1) + if self.head_dist is not None: + x, x_dist = self.head(x), self.head_dist(x) + if self.training and not torch.jit.is_scripting(): + return x, x_dist + else: + # during inference, return the average of both classifier predictions + return (x + x_dist) / 2 + else: + x = self.head(x) + return x + + +def checkpoint_filter_fn(state_dict, model): + if 'model' in state_dict: + # For deit models + state_dict = state_dict['model'] + D = model.state_dict() + for k in state_dict.keys(): + if D[k].ndim == 4 and state_dict[k].ndim == 2: + state_dict[k] = state_dict[k][:, :, None, None] + return state_dict + + +def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model_cfg = dict(**model_cfgs[variant], **kwargs) + model = build_model_with_cfg( + Levit, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **model_cfg) + #if fuse: + # utils.replace_batchnorm(model) + return model + diff --git a/timm/models/mlp_mixer.py b/timm/models/mlp_mixer.py index 92ca115b..5a6dce6f 100644 --- a/timm/models/mlp_mixer.py +++ b/timm/models/mlp_mixer.py @@ -273,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.): nn.init.ones_(m.weight) -def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) - +def _create_mixer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for MLP-Mixer models.') model = build_model_with_cfg( MlpMixer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 3c21eea1..1b67581e 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -110,6 +110,12 @@ default_cfgs = dict( eca_nfnet_l1=_dcfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth', pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0), + eca_nfnet_l2=_dcfg( + url='', + pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0), + eca_nfnet_l3=_dcfg( + url='', + pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0), nf_regnet_b0=_dcfg( url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'), @@ -244,6 +250,12 @@ model_cfgs = dict( eca_nfnet_l1=_nfnet_cfg( depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l2=_nfnet_cfg( + depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), + eca_nfnet_l3=_nfnet_cfg( + depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25, + attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), # EffNet influenced RegNet defs. # NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8. @@ -814,6 +826,22 @@ def eca_nfnet_l1(pretrained=False, **kwargs): return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs) +@register_model +def eca_nfnet_l2(pretrained=False, **kwargs): + """ ECA-NFNet-L2 w/ SiLU + My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs) + + +@register_model +def eca_nfnet_l3(pretrained=False, **kwargs): + """ ECA-NFNet-L3 w/ SiLU + My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn + """ + return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs) + + @register_model def nf_regnet_b0(pretrained=False, **kwargs): """ Normalization-Free RegNet-B0 diff --git a/timm/models/pit.py b/timm/models/pit.py index 040d96db..9c350861 100644 --- a/timm/models/pit.py +++ b/timm/models/pit.py @@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model): def _create_pit(variant, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - img_size = kwargs.pop('img_size', default_img_size) - num_classes = kwargs.pop('num_classes', default_num_classes) - if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( PoolingVisionTransformer, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/registry.py b/timm/models/registry.py index 9172ac7e..6927b6d6 100644 --- a/timm/models/registry.py +++ b/timm/models/registry.py @@ -50,7 +50,7 @@ def _natural_key(string_): return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())] -def list_models(filter='', module='', pretrained=False, exclude_filters=''): +def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False): """ Return list of available model names, sorted alphabetically Args: @@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet') pretrained (bool) - Include only models with pretrained weights if True exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter + name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases) Example: model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet' @@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): if filter: models = fnmatch.filter(models, filter) # include these models if exclude_filters: - if not isinstance(exclude_filters, list): + if not isinstance(exclude_filters, (tuple, list)): exclude_filters = [exclude_filters] for xf in exclude_filters: exclude_models = fnmatch.filter(models, xf) # exclude these models @@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''): models = set(models).difference(exclude_models) if pretrained: models = _model_has_pretrained.intersection(models) + if name_matches_cfg: + models = set(_model_default_cfgs).intersection(models) return list(sorted(models, key=_natural_key)) diff --git a/timm/models/tnt.py b/timm/models/tnt.py index 8e038718..8186cc4a 100644 --- a/timm/models/tnt.py +++ b/timm/models/tnt.py @@ -12,7 +12,7 @@ import torch.nn as nn from functools import partial from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.models.helpers import load_pretrained +from timm.models.helpers import build_model_with_cfg from timm.models.layers import Mlp, DropPath, trunc_normal_ from timm.models.layers.helpers import to_2tuple from timm.models.registry import register_model @@ -238,24 +238,31 @@ def checkpoint_filter_fn(state_dict, model): return state_dict +def _create_tnt(variant, pretrained=False, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + + model = build_model_with_cfg( + TNT, variant, pretrained, + default_cfg=default_cfgs[variant], + pretrained_filter_fn=checkpoint_filter_fn, + **kwargs) + return model + + @register_model def tnt_s_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_s_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3), - filter_fn=checkpoint_filter_fn) + model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg) return model @register_model def tnt_b_patch16_224(pretrained=False, **kwargs): - model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, + model_cfg = dict( + patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4, qkv_bias=False, **kwargs) - model.default_cfg = default_cfgs['tnt_b_patch16_224'] - if pretrained: - load_pretrained( - model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3)) + model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg) return model diff --git a/timm/models/twins.py b/timm/models/twins.py index a534d174..793d2ede 100644 --- a/timm/models/twins.py +++ b/timm/models/twins.py @@ -33,7 +33,7 @@ def _cfg(url='', **kwargs): 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, - 'first_conv': 'patch_embed.proj', 'classifier': 'head', + 'first_conv': 'patch_embeds.0.proj', 'classifier': 'head', **kwargs } @@ -361,25 +361,14 @@ class Twins(nn.Module): return x -def _create_twins(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] - - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) +def _create_twins(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') model = build_model_with_cfg( Twins, variant, pretrained, - default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, + default_cfg=default_cfgs[variant], **kwargs) - return model diff --git a/timm/models/visformer.py b/timm/models/visformer.py new file mode 100644 index 00000000..33a2fe87 --- /dev/null +++ b/timm/models/visformer.py @@ -0,0 +1,414 @@ +""" Visformer + +Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533 + +From original at https://github.com/danczs/Visformer + +""" +from copy import deepcopy + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg, overlay_external_default_cfg +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .registry import register_model + + +__all__ = ['Visformer'] + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head', + **kwargs + } + + +default_cfgs = dict( + visformer_tiny=_cfg(), + visformer_small=_cfg( + url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth' + ), +) + + +class LayerNormBHWC(nn.LayerNorm): + def __init__(self, dim): + super().__init__(dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class SpatialMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.in_features = in_features + self.out_features = out_features + self.spatial_conv = spatial_conv + if self.spatial_conv: + if group < 2: # net setting + hidden_features = in_features * 5 // 6 + else: + hidden_features = in_features * 2 + self.hidden_features = hidden_features + self.group = group + self.drop = nn.Dropout(drop) + self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False) + self.act1 = act_layer() + if self.spatial_conv: + self.conv2 = nn.Conv2d( + hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False) + self.act2 = act_layer() + else: + self.conv2 = None + self.act2 = None + self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False) + + def forward(self, x): + x = self.conv1(x) + x = self.act1(x) + x = self.drop(x) + if self.conv2 is not None: + x = self.conv2(x) + x = self.act2(x) + x = self.conv3(x) + x = self.drop(x) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.): + super().__init__() + self.dim = dim + self.num_heads = num_heads + head_dim = round(dim // num_heads * head_dim_ratio) + self.head_dim = head_dim + self.scale = head_dim ** -0.5 + self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, C, H, W = x.shape + x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3) + q, k, v = x[0], x[1], x[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class Block(nn.Module): + def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + group=8, attn_disabled=False, spatial_conv=False): + super().__init__() + self.spatial_conv = spatial_conv + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + if attn_disabled: + self.norm1 = None + self.attn = None + else: + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop) + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = SpatialMlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, + group=group, spatial_conv=spatial_conv) # new setting + + def forward(self, x): + if self.attn is not None: + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Visformer(nn.Module): + def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, + depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., + norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): + super().__init__() + self.num_classes = num_classes + self.num_features = self.embed_dim = embed_dim + self.init_channels = init_channels + self.img_size = img_size + self.vit_stem = vit_stem + self.pool = pool + self.conv_init = conv_init + if isinstance(depth, (list, tuple)): + self.stage_num1, self.stage_num2, self.stage_num3 = depth + depth = sum(depth) + else: + self.stage_num1 = self.stage_num3 = depth // 3 + self.stage_num2 = depth - self.stage_num1 - self.stage_num3 + self.pos_embed = pos_embed + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + # stage 1 + if self.vit_stem: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 16 + else: + if self.init_channels is None: + self.stem = None + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 8 + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False), + nn.BatchNorm2d(self.init_channels), + nn.ReLU(inplace=True) + ) + img_size //= 2 + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels, + embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False) + img_size //= 4 + + if self.pos_embed: + if self.vit_stem: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + else: + self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size)) + self.pos_drop = nn.Dropout(p=drop_rate) + self.stage1 = nn.ModuleList([ + Block( + dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1') + ) + for i in range(self.stage_num1) + ]) + + #stage2 + if not self.vit_stem: + self.patch_embed2 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2, + embed_dim=embed_dim, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size)) + self.stage2 = nn.ModuleList([ + Block( + dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1') + ) + for i in range(self.stage_num1, self.stage_num1+self.stage_num2) + ]) + + # stage 3 + if not self.vit_stem: + self.patch_embed3 = PatchEmbed( + img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim, + embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False) + img_size //= 2 + if self.pos_embed: + self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size)) + self.stage3 = nn.ModuleList([ + Block( + dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, + group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1') + ) + for i in range(self.stage_num1+self.stage_num2, depth) + ]) + + # head + if self.pool: + self.global_pooling = nn.AdaptiveAvgPool2d(1) + head_dim = embed_dim if self.vit_stem else embed_dim * 2 + self.norm = norm_layer(head_dim) + self.head = nn.Linear(head_dim, num_classes) + + # weights init + if self.pos_embed: + trunc_normal_(self.pos_embed1, std=0.02) + if not self.vit_stem: + trunc_normal_(self.pos_embed2, std=0.02) + trunc_normal_(self.pos_embed3, std=0.02) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + elif isinstance(m, nn.Conv2d): + if self.conv_init: + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + else: + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + + def forward(self, x): + if self.stem is not None: + x = self.stem(x) + + # stage 1 + x = self.patch_embed1(x) + if self.pos_embed: + x = x + self.pos_embed1 + x = self.pos_drop(x) + for b in self.stage1: + x = b(x) + + # stage 2 + if not self.vit_stem: + x = self.patch_embed2(x) + if self.pos_embed: + x = x + self.pos_embed2 + x = self.pos_drop(x) + for b in self.stage2: + x = b(x) + + # stage3 + if not self.vit_stem: + x = self.patch_embed3(x) + if self.pos_embed: + x = x + self.pos_embed3 + x = self.pos_drop(x) + for b in self.stage3: + x = b(x) + + # head + x = self.norm(x) + if self.pool: + x = self.global_pooling(x) + else: + x = x[:, :, 0, 0] + + x = self.head(x.view(x.size(0), -1)) + return x + + +def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs): + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') + model = build_model_with_cfg( + Visformer, variant, pretrained, + default_cfg=default_cfgs[variant], + **kwargs) + return model + + +@register_model +def visformer_tiny(pretrained=False, **kwargs): + model_cfg = dict( + img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg) + return model + + +@register_model +def visformer_small(pretrained=False, **kwargs): + model_cfg = dict( + img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8, + attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True, + embed_norm=nn.BatchNorm2d, **kwargs) + model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg) + return model + + +# @register_model +# def visformer_net1(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net2(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net3(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net4(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111', +# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net5(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net6(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model +# +# +# @register_model +# def visformer_net7(pretrained=False, **kwargs): +# model = Visformer( +# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000', +# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs) +# model.default_cfg = _cfg() +# return model + + + + diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index bef6dfb0..ff74d836 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -387,21 +387,20 @@ def checkpoint_filter_fn(state_dict, model): v = v.reshape(O, -1, H, W) elif k == 'pos_embed' and v.shape != model.pos_embed.shape: # To resize pos embedding when using model at different size from pretrained weights - v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1), - model.patch_embed.grid_size) + v = resize_pos_embed( + v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size) out_dict[k] = v return out_dict def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs): - if default_cfg is None: - default_cfg = deepcopy(default_cfgs[variant]) - overlay_external_default_cfg(default_cfg, kwargs) - default_num_classes = default_cfg['num_classes'] - default_img_size = default_cfg['input_size'][-2:] + default_cfg = default_cfg or default_cfgs[variant] + if kwargs.get('features_only', None): + raise RuntimeError('features_only not implemented for Vision Transformer models.') - num_classes = kwargs.pop('num_classes', default_num_classes) - img_size = kwargs.pop('img_size', default_img_size) + # NOTE this extra code to support handling of repr size for in21k pretrained models + default_num_classes = default_cfg['num_classes'] + num_classes = kwargs.get('num_classes', default_num_classes) repr_size = kwargs.pop('representation_size', None) if repr_size is not None and num_classes != default_num_classes: # Remove representation layer if fine-tuning. This may not always be the desired action, @@ -409,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw _logger.warning("Removing representation layer for fine-tuning.") repr_size = None - if kwargs.get('features_only', None): - raise RuntimeError('features_only not implemented for Vision Transformer models.') - model = build_model_with_cfg( VisionTransformer, variant, pretrained, default_cfg=default_cfg, - img_size=img_size, - num_classes=num_classes, representation_size=repr_size, pretrained_filter_fn=checkpoint_filter_fn, **kwargs) - return model diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py index 1656559f..9e5a62b2 100644 --- a/timm/models/vision_transformer_hybrid.py +++ b/timm/models/vision_transformer_hybrid.py @@ -27,7 +27,7 @@ def _cfg(url='', **kwargs): return { 'url': url, 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None, - 'crop_pct': .9, 'interpolation': 'bicubic', + 'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True, 'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5), 'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head', **kwargs @@ -107,11 +107,10 @@ class HybridEmbed(nn.Module): def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs): - default_cfg = deepcopy(default_cfgs[variant]) embed_layer = partial(HybridEmbed, backbone=backbone) kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set return _create_vision_transformer( - variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs) + variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs) def _resnetv2(layers=(3, 4, 9), **kwargs): diff --git a/timm/version.py b/timm/version.py index 2d802716..b94cbb01 100644 --- a/timm/version.py +++ b/timm/version.py @@ -1 +1 @@ -__version__ = '0.4.9' +__version__ = '0.4.10'