Merge branch 'master' into bits_and_tpu

pull/1239/head
Ross Wightman 3 years ago
commit b57a03bd0d

@ -17,8 +17,8 @@ jobs:
matrix:
os: [ubuntu-latest, macOS-latest]
python: ['3.8']
torch: ['1.8.0']
torchvision: ['0.9.0']
torch: ['1.8.1']
torchvision: ['0.9.1']
runs-on: ${{ matrix.os }}
steps:
@ -36,11 +36,16 @@ jobs:
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
run: |
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo apt update
sudo apt install -y google-perftools
- name: Install requirements
run: |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.0.12
- name: Run tests
env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
run: |
pytest -vv --durations=0 ./tests

@ -23,6 +23,14 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
## What's New
### May 25, 2021
* Add LeViT, Visformer, ConViT (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Add ResMLP and gMLP MLP vision models to the existing MLP Mixer impl
* Fix a number of torchscript issues with various vision transformer models
* Cleanup input_size/img_size override handling and improve testing / test coverage for all vision transformer and MLP models
* More flexible pos embedding resize (non-square) for ViT and TnT. Thanks [Alexander Soare](https://github.com/alexander-soare)
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
@ -39,7 +47,7 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modles
* Update ByoaNet attention modules
* Improve SA module inits
* Hack together experimental stand-alone Swin based attn module and `swinnet`
* Consistent '26t' model defs for experiments.
@ -166,30 +174,6 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
* Misc fixes for SiLU ONNX export, default_cfg missing from Feature extraction models, Linear layer w/ AMP + torchscript
* PyPi release @ 0.3.2 (needed by EfficientDet)
### Oct 30, 2020
* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!
### Oct 26, 2020
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
* ViT-B/16 - 84.2
* ViT-B/32 - 81.7
* ViT-L/16 - 85.2
* ViT-L/32 - 81.5
### Oct 21, 2020
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
### Oct 13, 2020
* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...
## Introduction
@ -207,6 +191,7 @@ A full version of the list below with source links can be found in the [document
* Bottleneck Transformers - https://arxiv.org/abs/2101.11605
* CaiT (Class-Attention in Image Transformers) - https://arxiv.org/abs/2103.17239
* CoaT (Co-Scale Conv-Attentional Image Transformers) - https://arxiv.org/abs/2104.06399
* ConViT (Soft Convolutional Inductive Biases Vision Transformers)- https://arxiv.org/abs/2103.10697
* CspNet (Cross-Stage Partial Networks) - https://arxiv.org/abs/1911.11929
* DeiT (Vision Transformer) - https://arxiv.org/abs/2012.12877
* DenseNet - https://arxiv.org/abs/1608.06993
@ -224,6 +209,7 @@ A full version of the list below with source links can be found in the [document
* MobileNet-V2 - https://arxiv.org/abs/1801.04381
* Single-Path NAS - https://arxiv.org/abs/1904.02877
* GhostNet - https://arxiv.org/abs/1911.11907
* gMLP - https://arxiv.org/abs/2105.08050
* GPU-Efficient Networks - https://arxiv.org/abs/2006.14090
* Halo Nets - https://arxiv.org/abs/2103.12731
* HardCoRe-NAS - https://arxiv.org/abs/2102.11646
@ -231,6 +217,7 @@ A full version of the list below with source links can be found in the [document
* Inception-V3 - https://arxiv.org/abs/1512.00567
* Inception-ResNet-V2 and Inception-V4 - https://arxiv.org/abs/1602.07261
* Lambda Networks - https://arxiv.org/abs/2102.08602
* LeViT (Vision Transformer in ConvNet's Clothing) - https://arxiv.org/abs/2104.01136
* MLP-Mixer - https://arxiv.org/abs/2105.01601
* MobileNet-V3 (MBConvNet w/ Efficient Head) - https://arxiv.org/abs/1905.02244
* NASNet-A - https://arxiv.org/abs/1707.07012
@ -240,6 +227,7 @@ A full version of the list below with source links can be found in the [document
* Pooling-based Vision Transformer (PiT) - https://arxiv.org/abs/2103.16302
* RegNet - https://arxiv.org/abs/2003.13678
* RepVGG - https://arxiv.org/abs/2101.03697
* ResMLP - https://arxiv.org/abs/2105.03404
* ResNet/ResNeXt
* ResNet (v1b/v1.5) - https://arxiv.org/abs/1512.03385
* ResNeXt - https://arxiv.org/abs/1611.05431
@ -257,6 +245,7 @@ A full version of the list below with source links can be found in the [document
* Swin Transformer - https://arxiv.org/abs/2103.14030
* Transformer-iN-Transformer (TNT) - https://arxiv.org/abs/2103.00112
* TResNet - https://arxiv.org/abs/2003.13630
* Twins (Spatial Attention in Vision Transformers) - https://arxiv.org/pdf/2104.13840.pdf
* Vision Transformer - https://arxiv.org/abs/2010.11929
* VovNet V2 and V1 - https://arxiv.org/abs/1911.06667
* Xception - https://arxiv.org/abs/1610.02357
@ -282,7 +271,7 @@ Several (less common) features that I often utilize in my projects are included.
* PyTorch DistributedDataParallel w/ multi-gpu, single process (AMP disabled as it crashes when enabled)
* PyTorch w/ single GPU single process (AMP optional)
* A dynamic global pool implementation that allows selecting from average pooling, max pooling, average + max, or concat([average, max]) at model creation. All global pooling is adaptive average by default and compatible with pretrained weights.
* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provide improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
* A 'Test Time Pool' wrapper that can wrap any of the included models and usually provides improved performance doing inference with input images larger than the training size. Idea adapted from original DPN implementation when I ported (https://github.com/cypw/DPNs)
* Learning rate schedulers
* Ideas adopted from
* [AllenNLP schedulers](https://github.com/allenai/allennlp/tree/master/allennlp/training/learning_rate_schedulers)
@ -306,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
* SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropBlock (https://arxiv.org/abs/1810.12890)
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
* An extensive selection of channel and/or spatial attention modules:
* Bottleneck Transformer - https://arxiv.org/abs/2101.11605
* CBAM - https://arxiv.org/abs/1807.06521
* Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
* Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
* Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
* Global Context (GC) - https://arxiv.org/abs/1904.11492
* Halo - https://arxiv.org/abs/2103.12731
* Involution - https://arxiv.org/abs/2103.06255
* Lambda Layer - https://arxiv.org/abs/2102.08602
* Non-Local (NL) - https://arxiv.org/abs/1711.07971
* Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
* Selective Kernel (SK) - (https://arxiv.org/abs/1903.06586
* Split (SPLAT) - https://arxiv.org/abs/2004.08955
* Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
## Results
@ -329,7 +332,7 @@ The root folder of the repository contains reference train, validation, and infe
## Awesome PyTorch Resources
One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and componenets here are listed below.
One of the greatest assets of PyTorch is the community and their contributions. A few of my favourite resources that pair well with the models and components here are listed below.
### Object Detection, Instance and Semantic Segmentation
* Detectron2 - https://github.com/facebookresearch/detectron2
@ -353,7 +356,7 @@ One of the greatest assets of PyTorch is the community and their contributions.
## Licenses
### Code
The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with license here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
The code here is licensed Apache 2.0. I've taken care to make sure any third party code included or adapted has compatible (permissive) licenses such as MIT, BSD, etc. I've made an effort to avoid any GPL / LGPL conflicts. That said, it is your responsibility to ensure you comply with licenses here and conditions of any dependent licenses. Where applicable, I've linked the sources/references for various components in docstrings. If you think I've missed anything please create an issue.
### Pretrained Weights
So far all of the pretrained weights available here are pretrained on ImageNet with a select few that have some additional pretraining (see extra note below). ImageNet was released for non-commercial research purposes only (http://www.image-net.org/download-faq). It's not clear what the implications of that are for the use of pretrained weights from that dataset. Any models I have trained with ImageNet are done for research purposes and one should assume that the original dataset license applies to the weights. It's best to seek legal advice if you intend to use the pretrained weights in a commercial product.

@ -1,5 +1,29 @@
# Archived Changes
### Oct 30, 2020
* Test with PyTorch 1.7 and fix a small top-n metric view vs reshape issue.
* Convert newly added 224x224 Vision Transformer weights from official JAX repo. 81.8 top-1 for B/16, 83.1 L/16.
* Support PyTorch 1.7 optimized, native SiLU (aka Swish) activation. Add mapping to 'silu' name, custom swish will eventually be deprecated.
* Fix regression for loading pretrained classifier via direct model entrypoint functions. Didn't impact create_model() factory usage.
* PyPi release @ 0.3.0 version!
### Oct 26, 2020
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
* ViT-B/16 - 84.2
* ViT-B/32 - 81.7
* ViT-L/16 - 85.2
* ViT-L/32 - 81.5
### Oct 21, 2020
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
### Oct 13, 2020
* Initial impl of Vision Transformer models. Both patch and hybrid (CNN backbone) variants. Currently trying to train...
* Adafactor and AdaHessian (FP32 only, no AMP) optimizers
* EdgeTPU-M (`efficientnet_em`) model trained in PyTorch, 79.3 top-1
* Pip release, doc updates pending a few more changes...
### Sept 18, 2020
* New ResNet 'D' weights. 72.7 (top-1) ResNet-18-D, 77.1 ResNet-34-D, 80.5 ResNet-50-D
* Added a few untrained defs for other ResNet models (66D, 101D, 152D, 200/200D)

@ -1,5 +1,33 @@
# Recent Changes
### May 25, 2021
* Add LeViT, Visformer, Convit (PR by Aman Arora), Twins (PR by paper authors) transformer models
* Cleanup input_size/img_size override handling and testing for all vision transformer models
* Add `efficientnetv2_rw_m` model and weights (started training before official code). 84.8 top-1, 53M params.
### May 14, 2021
* Add EfficientNet-V2 official model defs w/ ported weights from official [Tensorflow/Keras](https://github.com/google/automl/tree/master/efficientnetv2) impl.
* 1k trained variants: `tf_efficientnetv2_s/m/l`
* 21k trained variants: `tf_efficientnetv2_s/m/l_in21k`
* 21k pretrained -> 1k fine-tuned: `tf_efficientnetv2_s/m/l_in21ft1k`
* v2 models w/ v1 scaling: `tf_efficientnetv2_b0` through `b3`
* Rename my prev V2 guess `efficientnet_v2s` -> `efficientnetv2_rw_s`
* Some blank `efficientnetv2_*` models in-place for future native PyTorch training
### May 5, 2021
* Add MLP-Mixer models and port pretrained weights from [Google JAX impl](https://github.com/google-research/vision_transformer/tree/linen)
* Add CaiT models and pretrained weights from [FB](https://github.com/facebookresearch/deit)
* Add ResNet-RS models and weights from [TF](https://github.com/tensorflow/tpu/tree/master/models/official/resnet/resnet_rs). Thanks [Aman Arora](https://github.com/amaarora)
* Add CoaT models and weights. Thanks [Mohammed Rizin](https://github.com/morizin)
* Add new ImageNet-21k weights & finetuned weights for TResNet, MobileNet-V3, ViT models. Thanks [mrT](https://github.com/mrT23)
* Add GhostNet models and weights. Thanks [Kai Han](https://github.com/iamhankai)
* Update ByoaNet attention modles
* Improve SA module inits
* Hack together experimental stand-alone Swin based attn module and `swinnet`
* Consistent '26t' model defs for experiments.
* Add improved Efficientnet-V2S (prelim model def) weights. 83.8 top-1.
* WandB logging support
### April 13, 2021
* Add Swin Transformer models and weights from https://github.com/microsoft/Swin-Transformer

@ -8,10 +8,10 @@ from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
class MLP(nn.Module):
def __init__(self, act_layer="relu"):
def __init__(self, act_layer="relu", inplace=True):
super(MLP, self).__init__()
self.fc1 = nn.Linear(1000, 100)
self.act = create_act_layer(act_layer, inplace=True)
self.act = create_act_layer(act_layer, inplace=inplace)
self.fc2 = nn.Linear(100, 10)
def forward(self, x):
@ -21,14 +21,14 @@ class MLP(nn.Module):
return x
def _run_act_layer_grad(act_type):
def _run_act_layer_grad(act_type, inplace=True):
x = torch.rand(10, 1000) * 10
m = MLP(act_layer=act_type)
m = MLP(act_layer=act_type, inplace=inplace)
def _run(x, act_layer=''):
if act_layer:
# replace act layer if set
m.act = create_act_layer(act_layer, inplace=True)
m.act = create_act_layer(act_layer, inplace=inplace)
out = m(x)
l = (out - 0).pow(2).sum()
return l
@ -58,7 +58,7 @@ def test_mish_grad():
def test_hard_sigmoid_grad():
for _ in range(100):
_run_act_layer_grad('hard_sigmoid')
_run_act_layer_grad('hard_sigmoid', inplace=None)
def test_hard_swish_grad():

@ -15,38 +15,67 @@ if hasattr(torch._C, '_jit_set_profiling_executor'):
torch._C._jit_set_profiling_mode(False)
# transformer models don't support many of the spatial / feature based model functionalities
NON_STD_FILTERS = ['vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*']
NON_STD_FILTERS = [
'vit_*', 'tnt_*', 'pit_*', 'swin_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
'convit_*', 'levit*', 'visformer*']
NUM_NON_STD = len(NON_STD_FILTERS)
# exclude models that cause specific test failures
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*'] + NON_STD_FILTERS
'*resnetrs350*', '*resnetrs420*']
else:
EXCLUDE_FILTERS = NON_STD_FILTERS
EXCLUDE_FILTERS = []
MAX_FWD_SIZE = 384
MAX_BWD_SIZE = 128
MAX_FWD_FEAT_SIZE = 448
TARGET_FWD_SIZE = MAX_FWD_SIZE = 384
TARGET_BWD_SIZE = 128
MAX_BWD_SIZE = 320
MAX_FWD_OUT_SIZE = 448
TARGET_JIT_SIZE = 128
MAX_JIT_SIZE = 320
TARGET_FFEAT_SIZE = 96
MAX_FFEAT_SIZE = 256
def _get_input_size(model=None, model_name='', target=None):
if model is None:
assert model_name, "One of model or model_name must be provided"
input_size = get_model_default_value(model_name, 'input_size')
fixed_input_size = get_model_default_value(model_name, 'fixed_input_size')
min_input_size = get_model_default_value(model_name, 'min_input_size')
else:
default_cfg = model.default_cfg
input_size = default_cfg['input_size']
fixed_input_size = default_cfg.get('fixed_input_size', None)
min_input_size = default_cfg.get('min_input_size', None)
assert input_size is not None
if fixed_input_size:
return input_size
if min_input_size:
if target and max(input_size) > target:
input_size = min_input_size
else:
if target and max(input_size) > target:
input_size = tuple([min(x, target) for x in input_size])
return input_size
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-NUM_NON_STD]))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model"""
model = create_model(model_name, pretrained=False)
model.eval()
input_size = model.default_cfg['input_size']
if any([x > MAX_FWD_SIZE for x in input_size]):
if is_model_default_key(model_name, 'fixed_input_size'):
pytest.skip("Fixed input size model > limit.")
# cap forward test at max res 384 * 384 to keep resource down
input_size = tuple([min(x, MAX_FWD_SIZE) for x in input_size])
input_size = _get_input_size(model=model, target=TARGET_FWD_SIZE)
if max(input_size) > MAX_FWD_SIZE:
pytest.skip("Fixed input size model > limit.")
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
@ -55,26 +84,22 @@ def test_model_forward(model_name, batch_size):
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS))
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [2])
def test_model_backward(model_name, batch_size):
"""Run a single forward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
if max(input_size) > MAX_BWD_SIZE:
pytest.skip("Fixed input size model > limit.")
model = create_model(model_name, pretrained=False, num_classes=42)
num_params = sum([x.numel() for x in model.parameters()])
model.eval()
input_size = model.default_cfg['input_size']
if not is_model_default_key(model_name, 'fixed_input_size'):
min_input_size = get_model_default_value(model_name, 'min_input_size')
if min_input_size is not None:
input_size = min_input_size
else:
if any([x > MAX_BWD_SIZE for x in input_size]):
# cap backward test at 128 * 128 to keep resource usage down
input_size = tuple([min(x, MAX_BWD_SIZE) for x in input_size])
model.train()
inputs = torch.randn((batch_size, *input_size))
outputs = model(inputs)
if isinstance(outputs, tuple):
outputs = torch.cat(outputs)
outputs.mean().backward()
for n, x in model.named_parameters():
assert x.grad is not None, f'No gradient for {n}'
@ -85,7 +110,7 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'
@pytest.mark.timeout(120)
@pytest.mark.timeout(300)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
@pytest.mark.parametrize('batch_size', [1])
def test_model_default_cfgs(model_name, batch_size):
@ -99,10 +124,10 @@ def test_model_default_cfgs(model_name, batch_size):
pool_size = cfg['pool_size']
input_size = model.default_cfg['input_size']
if all([x <= MAX_FWD_FEAT_SIZE for x in input_size]) and \
if all([x <= MAX_FWD_OUT_SIZE for x in input_size]) and \
not any([fnmatch.fnmatch(model_name, x) for x in EXCLUDE_FILTERS]):
# output sizes only checked if default res <= 448 * 448 to keep resource down
input_size = tuple([min(x, MAX_FWD_FEAT_SIZE) for x in input_size])
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
input_tensor = torch.randn((batch_size, *input_size))
# test forward_features (always unpooled)
@ -153,26 +178,25 @@ if 'GITHUB_ACTIONS' not in os.environ:
EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'dla*', 'hrnet*', 'ghostnet*', # hopefully fix at some point
'vit_large_*', 'vit_huge_*',
]
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS))
@pytest.mark.parametrize(
'model_name', list_models(exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS, name_matches_cfg=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_torchscript(model_name, batch_size):
"""Run a single forward pass with each model"""
input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
if max(input_size) > MAX_JIT_SIZE:
pytest.skip("Fixed input size model > limit.")
with set_scriptable(True):
model = create_model(model_name, pretrained=False)
model.eval()
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 128, 128) # jit compile is already a bit slow and we've tested normal res already...
model = torch.jit.script(model)
outputs = model(torch.randn((batch_size, *input_size)))
@ -182,7 +206,7 @@ def test_model_forward_torchscript(model_name, batch_size):
EXCLUDE_FEAT_FILTERS = [
'*pruned*', # hopefully fix at some point
]
] + NON_STD_FILTERS
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FEAT_FILTERS += ['*resnext101_32x32d', '*resnext101_32x16d']
@ -198,12 +222,9 @@ def test_model_forward_features(model_name, batch_size):
expected_channels = model.feature_info.channels()
assert len(expected_channels) >= 4 # all models here should have at least 4 feature levels by default, some 5 or 6
if has_model_default_key(model_name, 'fixed_input_size'):
input_size = get_model_default_value(model_name, 'input_size')
elif has_model_default_key(model_name, 'min_input_size'):
input_size = get_model_default_value(model_name, 'min_input_size')
else:
input_size = (3, 96, 96) # jit compile is already a bit slow and we've tested normal res already...
input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
if max(input_size) > MAX_FFEAT_SIZE:
pytest.skip("Fixed input size model > limit.")
outputs = model(torch.randn((batch_size, *input_size)))
assert len(expected_channels) == len(outputs)

@ -26,8 +26,8 @@ from .parser import Parser
from timm.bits import get_global_device
MAX_TP_SIZE = 8 # maximum TF threadpool size, only doing jpeg decodes and queuing activities
SHUFFLE_SIZE = 16834 # samples to shuffle in DS queue
PREFETCH_SIZE = 4096 # samples to prefetch
SHUFFLE_SIZE = 20480 # samples to shuffle in DS queue
PREFETCH_SIZE = 2048 # samples to prefetch
def even_split_indices(split, n, num_samples):
@ -159,7 +159,7 @@ class ParserTfds(Parser):
# see warnings at https://pytorch.org/docs/stable/data.html#multi-process-data-loading
ds = ds.repeat() # allow wrap around and break iteration manually
if self.shuffle:
ds = ds.shuffle(min(self.num_samples // self._num_pipelines, SHUFFLE_SIZE), seed=0)
ds = ds.shuffle(min(self.num_samples, SHUFFLE_SIZE) // self._num_pipelines, seed=0)
ds = ds.prefetch(min(self.num_samples // self._num_pipelines, PREFETCH_SIZE))
self.ds = tfds.as_numpy(ds)

@ -2,6 +2,7 @@ from .byoanet import *
from .byobnet import *
from .cait import *
from .coat import *
from .convit import *
from .cspnet import *
from .densenet import *
from .dla import *
@ -15,6 +16,7 @@ from .hrnet import *
from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *
@ -34,11 +36,13 @@ from .swin_transformer import *
from .tnt import *
from .tresnet import *
from .vgg import *
from .visformer import *
from .vision_transformer import *
from .vision_transformer_hybrid import *
from .vovnet import *
from .xception import *
from .xception_aligned import *
from .twins import *
from .factory import create_model, split_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters

@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\
reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\
make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByoaNet']
__all__ = []
def _cfg(url='', **kwargs):
@ -47,92 +35,84 @@ default_cfgs = {
# GPU-Efficient (ResNet) weights
'botnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'botnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_botnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'halonet_h1': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet_h1_c4c5': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'halonet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'eca_halonext26ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8), min_input_size=(3, 256, 256)),
'lambda_resnet26t': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'lambda_resnet50t': _cfg(url='', min_input_size=(3, 128, 128)),
'eca_lambda_resnext26ts': _cfg(url='', min_input_size=(3, 128, 128), input_size=(3, 256, 256), pool_size=(8, 8)),
'swinnet26t_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
}
@dataclass
class ByoaBlocksCfg(BlocksCfg):
# FIXME allow overriding self_attn layer or args per block/stage,
pass
@dataclass
class ByoaCfg(ByobCfg):
blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
self_attn_layer: Optional[str] = None
self_attn_fixed_size: bool = False
self_attn_kwargs: dict = field(default_factory=lambda: dict())
'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
def interleave_attn(
types : Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg]:
""" interleave attn blocks
"""
assert len(types) == 2
if isinstance(every, int):
every = list(range(0 if first else every, d, every))
if not every:
every = [d - 1]
set(every)
blocks = []
for i in range(d):
block_type = types[1] if i in every else types[0]
blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)]
return tuple(blocks)
'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
}
model_cfgs = dict(
botnet26t=ByoaCfg(
botnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
botnet50ts=ByoaCfg(
botnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
eca_botnext26ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='bottleneck',
self_attn_kwargs=dict()
),
halonet_h1=ByoaCfg(
halonet_h1=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='7x7',
@ -141,12 +121,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet_h1_c4c5=ByoaCfg(
halonet_h1_c4c5=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='tiered',
@ -155,12 +135,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet26t=ByoaCfg(
halonet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -169,12 +149,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
halonet50ts=ByoaCfg(
halonet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -184,13 +164,29 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2)
),
eca_halonext26ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
lambda_resnet26t=ByoaCfg(
lambda_resnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -199,12 +195,12 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
lambda_resnet50t=ByoaCfg(
lambda_resnet50t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -213,190 +209,108 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
eca_lambda_resnext26ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
attn_layer='eca',
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
swinnet26t=ByoaCfg(
swinnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
swinnet50ts=ByoaCfg(
swinnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
)
@dataclass
class ByoaLayerFn(LayerFn):
self_attn: Optional[Callable] = None
eca_swinnext26ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='swin',
self_attn_kwargs=dict(win_size=8)
),
class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
if extra_conv:
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
stride = 1 # striding done via conv if enabled
else:
self.conv2_kxk = nn.Identity()
opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
# FIXME need to dilate self attn to have dilated network support, moop moop
self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.self_attn(x)
x = self.post_attn(x)
x = self.conv3_1x1(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
register_block('self_attn', SelfAttnBlock)
def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size:
assert feat_size is not None
block_kwargs['feat_size'] = feat_size
return block_kwargs
def get_layer_fns(cfg: ByoaCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = ByoaLayerFn(
conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
return layer_fn
class ByoaNet(nn.Module):
""" 'Bring-your-own-attention' Net
A ResNet inspired backbone that supports interleaving traditional residual blocks with
'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention
or similar module.
FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
torchscript limitations prevent sensible inheritance overrides.
"""
def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
layers = get_layer_fns(cfg)
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat[-1],
feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
self.feature_info.extend(stage_feat[:-1])
prev_chs = stage_feat[-1]['num_chs']
if cfg.num_features:
self.num_features = int(round(cfg.width_factor * cfg.num_features))
self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.feature_info += [
dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
_init_weights(m, n)
for m in self.modules():
# call each block's weight init for block-specific overrides to init above
if hasattr(m, 'init_weights'):
m.init_weights(zero_init_last_bn=zero_init_last_bn)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.final_conv(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
rednet26t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered', # FIXME RedNet uses involution in middle of stem
stem_pool='maxpool',
num_features=0,
self_attn_layer='involution',
self_attn_kwargs=dict()
),
rednet50ts=ByoModelCfg(
blocks=(
ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
act_layer='silu',
self_attn_layer='involution',
self_attn_kwargs=dict()
),
)
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByoaNet, variant, pretrained,
ByobNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),
@ -419,6 +333,14 @@ def botnet50ts_256(pretrained=False, **kwargs):
return _create_byoanet('botnet50ts_256', 'botnet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_botnext26ts_256(pretrained=False, **kwargs):
""" Bottleneck Transformer w/ ResNet26-T backbone. Bottleneck attn in final stage.
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_botnext26ts_256', 'eca_botnext26ts', pretrained=pretrained, **kwargs)
@register_model
def halonet_h1(pretrained=False, **kwargs):
""" HaloNet-H1. Halo attention in all stages as per the paper.
@ -449,6 +371,13 @@ def halonet50ts(pretrained=False, **kwargs):
return _create_byoanet('halonet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_halonext26ts(pretrained=False, **kwargs):
""" HaloNet w/ a ResNet26-t backbone, Hallo attention in final stage
"""
return _create_byoanet('eca_halonext26ts', pretrained=pretrained, **kwargs)
@register_model
def lambda_resnet26t(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
@ -463,6 +392,13 @@ def lambda_resnet50t(pretrained=False, **kwargs):
return _create_byoanet('lambda_resnet50t', pretrained=pretrained, **kwargs)
@register_model
def eca_lambda_resnext26ts(pretrained=False, **kwargs):
""" Lambda-ResNet-26T. Lambda layers in one C4 stage and all C5.
"""
return _create_byoanet('eca_lambda_resnext26ts', pretrained=pretrained, **kwargs)
@register_model
def swinnet26t_256(pretrained=False, **kwargs):
"""
@ -477,3 +413,25 @@ def swinnet50ts_256(pretrained=False, **kwargs):
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('swinnet50ts_256', 'swinnet50ts', pretrained=pretrained, **kwargs)
@register_model
def eca_swinnext26ts_256(pretrained=False, **kwargs):
"""
"""
kwargs.setdefault('img_size', 256)
return _create_byoanet('eca_swinnext26ts_256', 'eca_swinnext26ts', pretrained=pretrained, **kwargs)
@register_model
def rednet26t(pretrained=False, **kwargs):
"""
"""
return _create_byoanet('rednet26t', pretrained=pretrained, **kwargs)
@register_model
def rednet50ts(pretrained=False, **kwargs):
"""
"""
return _create_byoanet('rednet50ts', pretrained=pretrained, **kwargs)

@ -26,8 +26,7 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field, replace
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
from functools import partial
import torch
@ -36,10 +35,10 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
def _cfg(url='', **kwargs):
@ -87,35 +86,59 @@ default_cfgs = {
'repvgg_b3g4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
# experimental configs
'resnet51q': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
test_input_size=(3, 288, 288), crop_pct=1.0),
'resnet61q': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'geresnet50t': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet50t': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
}
@dataclass
class BlocksCfg:
class ByoBlockCfg:
type: Union[str, nn.Module]
d: int # block depth (number of block repeats in stage)
c: int # number of output channels for each block in stage
s: int = 2 # stride of stage (first block)
gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
br: float = 1. # bottleneck-ratio of blocks in stage
no_attn: bool = True # disable channel attn (ie SE) when layer is set for model
# NOTE: these config items override the model cfgs that are applied to all blocks by default
attn_layer: Optional[str] = None
attn_kwargs: Optional[Dict[str, Any]] = None
self_attn_layer: Optional[str] = None
self_attn_kwargs: Optional[Dict[str, Any]] = None
block_kwargs: Optional[Dict[str, Any]] = None
@dataclass
class ByobCfg:
blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
class ByoModelCfg:
blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
downsample: str = 'conv1x1'
stem_type: str = '3x3'
stem_pool: str = ''
stem_pool: Optional[str] = 'maxpool'
stem_chs: int = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last_bn: bool = True
fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
# NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
attn_layer: Optional[str] = None
attn_kwargs: dict = field(default_factory=lambda: dict())
self_attn_layer: Optional[str] = None
self_attn_kwargs: dict = field(default_factory=lambda: dict())
block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
@ -123,103 +146,287 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
group_size = 0
if groups > 0:
group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
bcfg = tuple([BlocksCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
return bcfg
model_cfgs = dict(
def interleave_blocks(
types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoBlockCfg]:
""" interleave 2 block types in stack
"""
assert len(types) == 2
if isinstance(every, int):
every = list(range(0 if first else every, d, every))
if not every:
every = [d - 1]
set(every)
blocks = []
for i in range(d):
block_type = types[1] if i in every else types[0]
blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)]
return tuple(blocks)
gernet_l=ByobCfg(
model_cfgs = dict(
gernet_l=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_m=ByobCfg(
gernet_m=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_s=ByobCfg(
gernet_s=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
),
stem_chs=13,
stem_pool=None,
num_features=1920,
),
repvgg_a2=ByobCfg(
repvgg_a2=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
stem_type='rep',
stem_chs=64,
),
repvgg_b0=ByobCfg(
repvgg_b0=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1=ByobCfg(
repvgg_b1=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1g4=ByobCfg(
repvgg_b1g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b2=ByobCfg(
repvgg_b2=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b2g4=ByobCfg(
repvgg_b2g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b3=ByobCfg(
repvgg_b3=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b3g4=ByobCfg(
repvgg_b3g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
resnet52q=ByobCfg(
# WARN: experimental, may vanish/change
resnet51q=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
),
stem_chs=128,
stem_type='quad2',
stem_pool=None,
num_features=2048,
act_layer='silu',
),
resnet61q=ByoModelCfg(
blocks=(
BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
),
stem_chs=128,
stem_type='quad',
stem_pool=None,
num_features=2048,
act_layer='silu',
block_kwargs=dict(extra_conv=True),
),
# WARN: experimental, may vanish/change
geresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='ge',
attn_kwargs=dict(extent=8, extra_params=True),
#attn_kwargs=dict(extent=8),
#block_kwargs=dict(attn_last=True)
),
# WARN: experimental, may vanish/change
gcresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='gc'
),
)
def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
@register_model
def gernet_l(pretrained=False, **kwargs):
""" GEResNet-Large (GENet-Large from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
@register_model
def gernet_m(pretrained=False, **kwargs):
""" GEResNet-Medium (GENet-Normal from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
@register_model
def gernet_s(pretrained=False, **kwargs):
""" EResNet-Small (GENet-Small from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a2(pretrained=False, **kwargs):
""" RepVGG-A2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b0(pretrained=False, **kwargs):
""" RepVGG-B0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1(pretrained=False, **kwargs):
""" RepVGG-B1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1g4(pretrained=False, **kwargs):
""" RepVGG-B1g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2(pretrained=False, **kwargs):
""" RepVGG-B2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2g4(pretrained=False, **kwargs):
""" RepVGG-B2g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3(pretrained=False, **kwargs):
""" RepVGG-B3
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3g4(pretrained=False, **kwargs):
""" RepVGG-B3g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)
@register_model
def resnet51q(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)
@register_model
def resnet61q(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)
@register_model
def geresnet50t(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs)
@register_model
def gcresnet50t(pretrained=False, **kwargs):
"""
"""
return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
block_cfgs = []
@ -243,6 +450,7 @@ class LayerFn:
norm_act: Callable = BatchNormAct2d
act: Callable = nn.ReLU
attn: Optional[Callable] = None
self_attn: Optional[Callable] = None
class DownsampleAvg(nn.Module):
@ -275,7 +483,8 @@ class BasicBlock(nn.Module):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -289,15 +498,19 @@ class BasicBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -317,7 +530,8 @@ class BottleneckBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -334,22 +548,36 @@ class BottleneckBlock(nn.Module):
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
if extra_conv:
self.conv2b_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block)
else:
self.conv2b_kxk = nn.Identity()
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.conv2b_kxk(x)
x = self.attn(x)
x = self.conv3_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
@ -368,7 +596,8 @@ class DarkBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -382,23 +611,28 @@ class DarkBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.attn(x)
x = self.conv2_kxk(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -415,7 +649,8 @@ class EdgeBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -431,14 +666,18 @@ class EdgeBlock(nn.Module):
self.conv1_kxk = layers.conv_norm_act(
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -446,6 +685,7 @@ class EdgeBlock(nn.Module):
x = self.conv1_kxk(x)
x = self.attn(x)
x = self.conv2_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -460,7 +700,7 @@ class RepVggBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -475,12 +715,14 @@ class RepVggBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
# NOTE this init overrides that base model init with specific changes for the block type
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight, .1, .1)
nn.init.normal_(m.bias, 0, .1)
if hasattr(self.attn, 'reset_parameters'):
self.attn.reset_parameters()
def forward(self, x):
if self.identity is None:
@ -495,12 +737,68 @@ class RepVggBlock(nn.Module):
return x
class SelfAttnBlock(nn.Module):
""" ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(SelfAttnBlock, self).__init__()
assert layers is not None
mid_chs = make_divisible(out_chs * bottle_ratio)
groups = num_groups(group_size, mid_chs)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
self.shortcut = create_downsample(
downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
apply_act=False, layers=layers)
else:
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
if extra_conv:
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
stride = 1 # striding done via conv if enabled
else:
self.conv2_kxk = nn.Identity()
opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
# FIXME need to dilate self attn to have dilated network support, moop moop
self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
if hasattr(self.self_attn, 'reset_parameters'):
self.self_attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.self_attn(x)
x = self.post_attn(x)
x = self.conv3_1x1(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
_block_registry = dict(
basic=BasicBlock,
bottle=BottleneckBlock,
dark=DarkBlock,
edge=EdgeBlock,
rep=RepVggBlock,
self_attn=SelfAttnBlock,
)
@ -552,7 +850,7 @@ class Stem(nn.Sequential):
curr_stride *= s
prev_feat = conv_name
if 'max' in pool.lower():
if pool and 'max' in pool.lower():
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
self.add_module('pool', nn.MaxPool2d(3, 2, 1))
curr_stride *= 2
@ -564,7 +862,7 @@ class Stem(nn.Sequential):
def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
layers = layers or LayerFn()
assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3')
assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
if 'quad' in stem_type:
# based on NFNet stem, stack of 4 3x3 convs
num_act = 2 if 'quad2' in stem_type else None
@ -601,9 +899,58 @@ def reduce_feat_size(feat_size, stride=2):
return None if feat_size is None else tuple([s // stride for s in feat_size])
def override_kwargs(block_kwargs, model_kwargs):
""" Override model level attn/self-attn/block kwargs w/ block level
NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs
for the block if set to anything that isn't None.
i.e. an empty block_kwargs dict will remove kwargs set at model level for that block
"""
out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs
return out_kwargs or {} # make sure None isn't returned
def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ):
layer_fns = block_kwargs['layers']
# override attn layer / args with block local config
if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
# override attn layer config
if not block_cfg.attn_layer:
# empty string for attn_layer type will disable attn for this block
attn_layer = None
else:
attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None
layer_fns = replace(layer_fns, attn=attn_layer)
# override self-attn layer / args with block local cfg
if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
# override attn layer config
if not block_cfg.self_attn_layer:
# empty string for self_attn_layer type will disable attn for this block
self_attn_layer = None
else:
self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \
if self_attn_layer is not None else None
layer_fns = replace(layer_fns, self_attn=self_attn_layer)
block_kwargs['layers'] = layer_fns
# add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
def create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat,
feat_size=None, layers=None, extra_args_fn=None):
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
layers = layers or LayerFn()
feature_info = []
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
@ -641,8 +988,10 @@ def create_byob_stages(
drop_path_rate=dpr[stage_idx][block_idx],
layers=layers,
)
if extra_args_fn is not None:
extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
if block_cfg.type in ('self_attn',):
# add feat_size arg for blocks that support/need it
block_kwargs['feat_size'] = feat_size
block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
blocks += [create_block(block_cfg.type, **block_kwargs)]
first_dilation = dilation
prev_chs = out_chs
@ -656,12 +1005,13 @@ def create_byob_stages(
return nn.Sequential(*stages), feature_info
def get_layer_fns(cfg: ByobCfg):
def get_layer_fns(cfg: ByoModelCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
return layer_fn
@ -673,19 +1023,24 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByobCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, drop_rate=0., drop_path_rate=0.):
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
self.stages, stage_feat = create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size)
self.feature_info.extend(stage_feat[:-1])
prev_chs = stage_feat[-1]['num_chs']
@ -748,91 +1103,3 @@ def _create_byobnet(variant, pretrained=False, **kwargs):
model_cfg=model_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def gernet_l(pretrained=False, **kwargs):
""" GEResNet-Large (GENet-Large from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
@register_model
def gernet_m(pretrained=False, **kwargs):
""" GEResNet-Medium (GENet-Normal from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
@register_model
def gernet_s(pretrained=False, **kwargs):
""" EResNet-Small (GENet-Small from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a2(pretrained=False, **kwargs):
""" RepVGG-A2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b0(pretrained=False, **kwargs):
""" RepVGG-B0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1(pretrained=False, **kwargs):
""" RepVGG-B1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1g4(pretrained=False, **kwargs):
""" RepVGG-B1g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2(pretrained=False, **kwargs):
""" RepVGG-B2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2g4(pretrained=False, **kwargs):
""" RepVGG-B2g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3(pretrained=False, **kwargs):
""" RepVGG-B3
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3g4(pretrained=False, **kwargs):
""" RepVGG-B3g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)

@ -306,26 +306,15 @@ def checkpoint_filter_fn(state_dict, model=None):
return checkpoint_no_module
def _create_cait(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_cait(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Cait, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -7,19 +7,19 @@ Official CoaT code at: https://github.com/mlpc-ucsd/CoaT
Modified from timm/models/vision_transformer.py
"""
from typing import Tuple, Dict, Any, Optional
from copy import deepcopy
from functools import partial
from typing import Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import PatchEmbed, Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from functools import partial
from torch import nn
__all__ = [
"coat_tiny",
@ -34,7 +34,7 @@ def _cfg_coat(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed1.proj', 'classifier': 'head',
**kwargs
@ -42,15 +42,21 @@ def _cfg_coat(url='', **kwargs):
default_cfgs = {
'coat_tiny': _cfg_coat(),
'coat_mini': _cfg_coat(),
'coat_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_tiny-473c2a20.pth'
),
'coat_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_mini-2c6baf49.pth'
),
'coat_lite_tiny': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_tiny-461b07a7.pth'
),
'coat_lite_mini': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_mini-d7842000.pth'
),
'coat_lite_small': _cfg_coat(),
'coat_lite_small': _cfg_coat(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-coat-weights/coat_lite_small-fea1d5a1.pth'
),
}
@ -120,11 +126,11 @@ class ConvRelPosEnc(nn.Module):
class FactorAtt_ConvRelPosEnc(nn.Module):
""" Factorized attention with convolutional relative position encoding class. """
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., shared_crpe=None):
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0., shared_crpe=None):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop) # Note: attn_drop is actually not used.
@ -190,9 +196,8 @@ class ConvPosEnc(nn.Module):
class SerialBlock(nn.Module):
""" Serial block class.
Note: In this implementation, each serial block only contains a conv-attention and a FFN (MLP) module. """
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
shared_cpe=None, shared_crpe=None):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_cpe=None, shared_crpe=None):
super().__init__()
# Conv-Attention.
@ -200,8 +205,7 @@ class SerialBlock(nn.Module):
self.norm1 = norm_layer(dim)
self.factoratt_crpe = FactorAtt_ConvRelPosEnc(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpe)
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop, shared_crpe=shared_crpe)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
# MLP.
@ -226,27 +230,24 @@ class SerialBlock(nn.Module):
class ParallelBlock(nn.Module):
""" Parallel block class. """
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
shared_cpes=None, shared_crpes=None):
def __init__(self, dims, num_heads, mlp_ratios=[], qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, shared_crpes=None):
super().__init__()
# Conv-Attention.
self.cpes = shared_cpes
self.norm12 = norm_layer(dims[1])
self.norm13 = norm_layer(dims[2])
self.norm14 = norm_layer(dims[3])
self.factoratt_crpe2 = FactorAtt_ConvRelPosEnc(
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[1], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[1]
)
self.factoratt_crpe3 = FactorAtt_ConvRelPosEnc(
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[2], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[2]
)
self.factoratt_crpe4 = FactorAtt_ConvRelPosEnc(
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
dims[3], num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop,
shared_crpe=shared_crpes[3]
)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
@ -262,15 +263,15 @@ class ParallelBlock(nn.Module):
self.mlp2 = self.mlp3 = self.mlp4 = Mlp(
in_features=dims[1], hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def upsample(self, x, factor, size):
def upsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map up-sampling. """
return self.interpolate(x, scale_factor=factor, size=size)
def downsample(self, x, factor, size):
def downsample(self, x, factor: float, size: Tuple[int, int]):
""" Feature map down-sampling. """
return self.interpolate(x, scale_factor=1.0/factor, size=size)
def interpolate(self, x, scale_factor, size):
def interpolate(self, x, scale_factor: float, size: Tuple[int, int]):
""" Feature map interpolation. """
B, N, C = x.shape
H, W = size
@ -280,33 +281,28 @@ class ParallelBlock(nn.Module):
img_tokens = x[:, 1:, :]
img_tokens = img_tokens.transpose(1, 2).reshape(B, C, H, W)
img_tokens = F.interpolate(img_tokens, scale_factor=scale_factor, mode='bilinear')
img_tokens = F.interpolate(
img_tokens, scale_factor=scale_factor, recompute_scale_factor=False, mode='bilinear', align_corners=False)
img_tokens = img_tokens.reshape(B, C, -1).transpose(1, 2)
out = torch.cat((cls_token, img_tokens), dim=1)
return out
def forward(self, x1, x2, x3, x4, sizes):
_, (H2, W2), (H3, W3), (H4, W4) = sizes
# Conv-Attention.
x2 = self.cpes[1](x2, size=(H2, W2)) # Note: x1 is ignored.
x3 = self.cpes[2](x3, size=(H3, W3))
x4 = self.cpes[3](x4, size=(H4, W4))
def forward(self, x1, x2, x3, x4, sizes: List[Tuple[int, int]]):
_, S2, S3, S4 = sizes
cur2 = self.norm12(x2)
cur3 = self.norm13(x3)
cur4 = self.norm14(x4)
cur2 = self.factoratt_crpe2(cur2, size=(H2, W2))
cur3 = self.factoratt_crpe3(cur3, size=(H3, W3))
cur4 = self.factoratt_crpe4(cur4, size=(H4, W4))
upsample3_2 = self.upsample(cur3, factor=2, size=(H3, W3))
upsample4_3 = self.upsample(cur4, factor=2, size=(H4, W4))
upsample4_2 = self.upsample(cur4, factor=4, size=(H4, W4))
downsample2_3 = self.downsample(cur2, factor=2, size=(H2, W2))
downsample3_4 = self.downsample(cur3, factor=2, size=(H3, W3))
downsample2_4 = self.downsample(cur2, factor=4, size=(H2, W2))
cur2 = self.factoratt_crpe2(cur2, size=S2)
cur3 = self.factoratt_crpe3(cur3, size=S3)
cur4 = self.factoratt_crpe4(cur4, size=S4)
upsample3_2 = self.upsample(cur3, factor=2., size=S3)
upsample4_3 = self.upsample(cur4, factor=2., size=S4)
upsample4_2 = self.upsample(cur4, factor=4., size=S4)
downsample2_3 = self.downsample(cur2, factor=2., size=S2)
downsample3_4 = self.downsample(cur3, factor=2., size=S3)
downsample2_4 = self.downsample(cur2, factor=4., size=S2)
cur2 = cur2 + upsample3_2 + upsample4_2
cur3 = cur3 + upsample4_3 + downsample2_3
cur4 = cur4 + downsample3_4 + downsample2_4
@ -330,11 +326,11 @@ class ParallelBlock(nn.Module):
class CoaT(nn.Module):
""" CoaT class. """
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[0, 0, 0, 0],
serial_depths=[0, 0, 0, 0], parallel_depth=0,
num_heads=0, mlp_ratios=[0, 0, 0, 0], qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features = None, crpe_window=None, **kwargs):
def __init__(
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=(0, 0, 0, 0),
serial_depths=(0, 0, 0, 0), parallel_depth=0, num_heads=0, mlp_ratios=(0, 0, 0, 0), qkv_bias=True,
drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
return_interm_layers=False, out_features=None, crpe_window=None, **kwargs):
super().__init__()
crpe_window = crpe_window or {3: 2, 5: 3, 7: 3}
self.return_interm_layers = return_interm_layers
@ -342,17 +338,18 @@ class CoaT(nn.Module):
self.num_classes = num_classes
# Patch embeddings.
img_size = to_2tuple(img_size)
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dims[0], norm_layer=nn.LayerNorm)
self.patch_embed2 = PatchEmbed(
img_size=img_size // 4, patch_size=2, in_chans=embed_dims[0],
img_size=[x // 4 for x in img_size], patch_size=2, in_chans=embed_dims[0],
embed_dim=embed_dims[1], norm_layer=nn.LayerNorm)
self.patch_embed3 = PatchEmbed(
img_size=img_size // 8, patch_size=2, in_chans=embed_dims[1],
img_size=[x // 8 for x in img_size], patch_size=2, in_chans=embed_dims[1],
embed_dim=embed_dims[2], norm_layer=nn.LayerNorm)
self.patch_embed4 = PatchEmbed(
img_size=img_size // 16, patch_size=2, in_chans=embed_dims[2],
img_size=[x // 16 for x in img_size], patch_size=2, in_chans=embed_dims[2],
embed_dim=embed_dims[3], norm_layer=nn.LayerNorm)
# Class tokens.
@ -380,7 +377,7 @@ class CoaT(nn.Module):
# Serial blocks 1.
self.serial_blocks1 = nn.ModuleList([
SerialBlock(
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[0], num_heads=num_heads, mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe1, shared_crpe=self.crpe1
)
@ -390,7 +387,7 @@ class CoaT(nn.Module):
# Serial blocks 2.
self.serial_blocks2 = nn.ModuleList([
SerialBlock(
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[1], num_heads=num_heads, mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe2, shared_crpe=self.crpe2
)
@ -400,7 +397,7 @@ class CoaT(nn.Module):
# Serial blocks 3.
self.serial_blocks3 = nn.ModuleList([
SerialBlock(
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[2], num_heads=num_heads, mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe3, shared_crpe=self.crpe3
)
@ -410,7 +407,7 @@ class CoaT(nn.Module):
# Serial blocks 4.
self.serial_blocks4 = nn.ModuleList([
SerialBlock(
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
dim=embed_dims[3], num_heads=num_heads, mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpe=self.cpe4, shared_crpe=self.crpe4
)
@ -422,10 +419,9 @@ class CoaT(nn.Module):
if self.parallel_depth > 0:
self.parallel_blocks = nn.ModuleList([
ParallelBlock(
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_cpes=[self.cpe1, self.cpe2, self.cpe3, self.cpe4],
shared_crpes=[self.crpe1, self.crpe2, self.crpe3, self.crpe4]
dims=embed_dims, num_heads=num_heads, mlp_ratios=mlp_ratios, qkv_bias=qkv_bias,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr, norm_layer=norm_layer,
shared_crpes=(self.crpe1, self.crpe2, self.crpe3, self.crpe4)
)
for _ in range(parallel_depth)]
)
@ -434,9 +430,11 @@ class CoaT(nn.Module):
# Classification head(s).
if not self.return_interm_layers:
self.norm1 = norm_layer(embed_dims[0])
self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2])
if self.parallel_blocks is not None:
self.norm2 = norm_layer(embed_dims[1])
self.norm3 = norm_layer(embed_dims[2])
else:
self.norm2 = self.norm3 = None
self.norm4 = norm_layer(embed_dims[3])
if self.parallel_depth > 0:
@ -546,6 +544,7 @@ class CoaT(nn.Module):
# Parallel blocks.
for blk in self.parallel_blocks:
x2, x3, x4 = self.cpe2(x2, (H2, W2)), self.cpe3(x3, (H3, W3)), self.cpe4(x4, (H4, W4))
x1, x2, x3, x4 = blk(x1, x2, x3, x4, sizes=[(H1, W1), (H2, W2), (H3, W3), (H4, W4)])
if not torch.jit.is_scripting() and self.return_interm_layers:
@ -590,52 +589,70 @@ class CoaT(nn.Module):
return x
def checkpoint_filter_fn(state_dict, model):
out_dict = {}
for k, v in state_dict.items():
# original model had unused norm layers, removing them requires filtering pretrained checkpoints
if k.startswith('norm1') or \
(model.norm2 is None and k.startswith('norm2')) or \
(model.norm3 is None and k.startswith('norm3')):
continue
out_dict[k] = v
return out_dict
def _create_coat(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
CoaT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def coat_tiny(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[152, 152, 152, 152], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_tiny']
model = _create_coat('coat_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_mini(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[152, 216, 216, 216], serial_depths=[2, 2, 2, 2], parallel_depth=6,
num_heads=8, mlp_ratios=[4, 4, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_mini']
model = _create_coat('coat_mini', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_tiny(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 256, 320], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_tiny']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_coat('coat_lite_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_mini(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[2, 2, 2, 2], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
# FIXME use builder
model.default_cfg = default_cfgs['coat_lite_mini']
if pretrained:
load_pretrained(model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_coat('coat_lite_mini', pretrained=pretrained, **model_cfg)
return model
@register_model
def coat_lite_small(pretrained=False, **kwargs):
model = CoaT(
model_cfg = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], serial_depths=[3, 4, 6, 3], parallel_depth=0,
num_heads=8, mlp_ratios=[8, 8, 4, 4], **kwargs)
model.default_cfg = default_cfgs['coat_lite_small']
model = _create_coat('coat_lite_small', pretrained=pretrained, **model_cfg)
return model

@ -0,0 +1,353 @@
""" ConViT Model
@article{d2021convit,
title={ConViT: Improving Vision Transformers with Soft Convolutional Inductive Biases},
author={d'Ascoli, St{\'e}phane and Touvron, Hugo and Leavitt, Matthew and Morcos, Ari and Biroli, Giulio and Sagun, Levent},
journal={arXiv preprint arXiv:2103.10697},
year={2021}
}
Paper link: https://arxiv.org/abs/2103.10697
Original code: https://github.com/facebookresearch/convit, original copyright below
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
'''These modules are adapted from those of timm, see
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
'''
import torch
import torch.nn as nn
from functools import partial
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import DropPath, to_2tuple, trunc_normal_, PatchEmbed, Mlp
from .registry import register_model
from .vision_transformer_hybrid import HybridEmbed
import torch
import torch.nn as nn
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'fixed_input_size': True,
'first_conv': 'patch_embed.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
# ConViT
'convit_tiny': _cfg(
url="https://dl.fbaipublicfiles.com/convit/convit_tiny.pth"),
'convit_small': _cfg(
url="https://dl.fbaipublicfiles.com/convit/convit_small.pth"),
'convit_base': _cfg(
url="https://dl.fbaipublicfiles.com/convit/convit_base.pth")
}
class GPSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,
locality_strength=1.):
super().__init__()
self.num_heads = num_heads
self.dim = dim
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.locality_strength = locality_strength
self.qk = nn.Linear(dim, dim * 2, bias=qkv_bias)
self.v = nn.Linear(dim, dim, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.pos_proj = nn.Linear(3, num_heads)
self.proj_drop = nn.Dropout(proj_drop)
self.locality_strength = locality_strength
self.gating_param = nn.Parameter(torch.ones(self.num_heads))
self.rel_indices: torch.Tensor = torch.zeros(1, 1, 1, 3) # silly torchscript hack, won't work with None
def forward(self, x):
B, N, C = x.shape
if self.rel_indices is None or self.rel_indices.shape[1] != N:
self.rel_indices = self.get_rel_indices(N)
attn = self.get_attention(x)
v = self.v(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
def get_attention(self, x):
B, N, C = x.shape
qk = self.qk(x).reshape(B, N, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k = qk[0], qk[1]
pos_score = self.rel_indices.expand(B, -1, -1, -1)
pos_score = self.pos_proj(pos_score).permute(0, 3, 1, 2)
patch_score = (q @ k.transpose(-2, -1)) * self.scale
patch_score = patch_score.softmax(dim=-1)
pos_score = pos_score.softmax(dim=-1)
gating = self.gating_param.view(1, -1, 1, 1)
attn = (1. - torch.sigmoid(gating)) * patch_score + torch.sigmoid(gating) * pos_score
attn /= attn.sum(dim=-1).unsqueeze(-1)
attn = self.attn_drop(attn)
return attn
def get_attention_map(self, x, return_map=False):
attn_map = self.get_attention(x).mean(0) # average over batch
distances = self.rel_indices.squeeze()[:, :, -1] ** .5
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / distances.size(0)
if return_map:
return dist, attn_map
else:
return dist
def local_init(self):
self.v.weight.data.copy_(torch.eye(self.dim))
locality_distance = 1 # max(1,1/locality_strength**.5)
kernel_size = int(self.num_heads ** .5)
center = (kernel_size - 1) / 2 if kernel_size % 2 == 0 else kernel_size // 2
for h1 in range(kernel_size):
for h2 in range(kernel_size):
position = h1 + kernel_size * h2
self.pos_proj.weight.data[position, 2] = -1
self.pos_proj.weight.data[position, 1] = 2 * (h1 - center) * locality_distance
self.pos_proj.weight.data[position, 0] = 2 * (h2 - center) * locality_distance
self.pos_proj.weight.data *= self.locality_strength
def get_rel_indices(self, num_patches: int) -> torch.Tensor:
img_size = int(num_patches ** .5)
rel_indices = torch.zeros(1, num_patches, num_patches, 3)
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
indx = ind.repeat(img_size, img_size)
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
indd = indx ** 2 + indy ** 2
rel_indices[:, :, :, 2] = indd.unsqueeze(0)
rel_indices[:, :, :, 1] = indy.unsqueeze(0)
rel_indices[:, :, :, 0] = indx.unsqueeze(0)
device = self.qk.weight.device
return rel_indices.to(device)
class MHSA(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = qk_scale or head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def get_attention_map(self, x, return_map=False):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn_map = (q @ k.transpose(-2, -1)) * self.scale
attn_map = attn_map.softmax(dim=-1).mean(0)
img_size = int(N ** .5)
ind = torch.arange(img_size).view(1, -1) - torch.arange(img_size).view(-1, 1)
indx = ind.repeat(img_size, img_size)
indy = ind.repeat_interleave(img_size, dim=0).repeat_interleave(img_size, dim=1)
indd = indx ** 2 + indy ** 2
distances = indd ** .5
distances = distances.to('cuda')
dist = torch.einsum('nm,hnm->h', (distances, attn_map)) / N
if return_map:
return dist, attn_map
else:
return dist
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, use_gpsa=True, **kwargs):
super().__init__()
self.norm1 = norm_layer(dim)
self.use_gpsa = use_gpsa
if self.use_gpsa:
self.attn = GPSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, **kwargs)
else:
self.attn = MHSA(
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
proj_drop=drop, **kwargs)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class ConViT(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, global_pool=None,
local_up_to_layer=3, locality_strength=1., use_pos_embed=True):
super().__init__()
embed_dim *= num_heads
self.num_classes = num_classes
self.local_up_to_layer = local_up_to_layer
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.locality_strength = locality_strength
self.use_pos_embed = use_pos_embed
if hybrid_backbone is not None:
self.patch_embed = HybridEmbed(
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
else:
self.patch_embed = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.num_patches = num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop_rate)
if self.use_pos_embed:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
trunc_normal_(self.pos_embed, std=.02)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=True,
locality_strength=locality_strength)
if i < local_up_to_layer else
Block(
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
use_gpsa=False)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.feature_info = [dict(num_chs=embed_dim, reduction=0, module='head')]
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
for n, m in self.named_modules():
if hasattr(m, 'local_init'):
m.local_init()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
B = x.shape[0]
x = self.patch_embed(x)
cls_tokens = self.cls_token.expand(B, -1, -1)
if self.use_pos_embed:
x = x + self.pos_embed
x = self.pos_drop(x)
for u, blk in enumerate(self.blocks):
if u == self.local_up_to_layer:
x = torch.cat((cls_tokens, x), dim=1)
x = blk(x)
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_convit(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
return build_model_with_cfg(
ConViT, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
@register_model
def convit_tiny(pretrained=False, **kwargs):
model_args = dict(
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
num_heads=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model = _create_convit(variant='convit_tiny', pretrained=pretrained, **model_args)
return model
@register_model
def convit_small(pretrained=False, **kwargs):
model_args = dict(
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
num_heads=9, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model = _create_convit(variant='convit_small', pretrained=pretrained, **model_args)
return model
@register_model
def convit_base(pretrained=False, **kwargs):
model_args = dict(
local_up_to_layer=10, locality_strength=1.0, embed_dim=48,
num_heads=16, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model = _create_convit(variant='convit_base', pretrained=pretrained, **model_args)
return model

@ -91,6 +91,12 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
interpolation='bilinear'),
# NOTE experimenting with alternate attention
'eca_efficientnet_b0': _cfg(
url=''),
'gc_efficientnet_b0': _cfg(
url=''),
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
'efficientnet_b1': _cfg(
@ -162,6 +168,9 @@ default_cfgs = {
'efficientnetv2_rw_s': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_v2s_ra2_288-a6477665.pth',
input_size=(3, 288, 288), test_input_size=(3, 384, 384), pool_size=(9, 9), crop_pct=1.0),
'efficientnetv2_rw_m': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnetv2_rw_m_agc-3d90cb1e.pth',
input_size=(3, 320, 320), test_input_size=(3, 416, 416), pool_size=(10, 10), crop_pct=1.0),
'efficientnetv2_s': _cfg(
url='',
@ -173,7 +182,6 @@ default_cfgs = {
url='',
input_size=(3, 384, 384), test_input_size=(3, 480, 480), pool_size=(12, 12), crop_pct=1.0),
'tf_efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth',
input_size=(3, 224, 224)),
@ -1221,6 +1229,26 @@ def efficientnet_b0(pretrained=False, **kwargs):
return model
@register_model
def eca_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ ECA attn """
# NOTE experimental config
model = _gen_efficientnet(
'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
pretrained=pretrained, **kwargs)
return model
@register_model
def gc_efficientnet_b0(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ GlobalContext """
# NOTE experminetal config
model = _gen_efficientnet(
'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b1(pretrained=False, **kwargs):
""" EfficientNet-B1 """
@ -1461,7 +1489,7 @@ def efficientnet_b3_pruned(pretrained=False, **kwargs):
@register_model
def efficientnetv2_rw_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small.
""" EfficientNet-V2 Small RW variant.
NOTE: This is my initial (pre official code release) w/ some differences.
See efficientnetv2_s and tf_efficientnetv2_s for versions that match the official w/ PyTorch vs TF padding
"""
@ -1469,6 +1497,16 @@ def efficientnetv2_rw_s(pretrained=False, **kwargs):
return model
@register_model
def efficientnetv2_rw_m(pretrained=False, **kwargs):
""" EfficientNet-V2 Medium RW variant.
"""
model = _gen_efficientnetv2_s(
'efficientnetv2_rw_m', channel_multiplier=1.2, depth_multiplier=(1.2,) * 4 + (1.6,) * 2, rw=True,
pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnetv2_s(pretrained=False, **kwargs):
""" EfficientNet-V2 Small. """

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, make_divisible
from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
from .layers.activations import sigmoid
__all__ = [
@ -19,33 +19,32 @@ class SqueezeExcite(nn.Module):
Args:
in_chs (int): input channels to layer
se_ratio (float): ratio of squeeze reduction
rd_ratio (float): ratio of squeeze reduction
act_layer (nn.Module): activation layer of containing block
gate_fn (Callable): attention gate function
block_in_chs (int): input channels of containing block (for calculating reduction from)
reduce_from_block (bool): calculate reduction from block input channels if True
gate_layer (Callable): attention gate function
force_act_layer (nn.Module): override block's activation fn if this is set/bound
divisor (int): make reduction channels divisible by this
rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
"""
def __init__(
self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
super(SqueezeExcite, self).__init__()
reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
if rd_channels is None:
rd_round_fn = rd_round_fn or round
rd_channels = rd_round_fn(in_chs * rd_ratio)
act_layer = force_act_layer or act_layer
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.act1 = act_layer(inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.gate_fn = gate_fn
self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
return x * self.gate_fn(x_se)
return x * self.gate(x_se)
class ConvBnAct(nn.Module):
@ -87,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
se_layer=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
@ -101,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs)
@ -146,12 +144,11 @@ class InvertedResidual(nn.Module):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
@ -168,8 +165,7 @@ class InvertedResidual(nn.Module):
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = se_layer(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@ -215,8 +211,8 @@ class CondConvResidual(InvertedResidual):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
@ -224,8 +220,8 @@ class CondConvResidual(InvertedResidual):
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
@ -274,8 +270,8 @@ class EdgeResidual(nn.Module):
def __init__(
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
super(EdgeResidual, self).__init__()
if force_in_chs > 0:
mid_chs = make_divisible(force_in_chs * exp_ratio)
@ -292,8 +288,7 @@ class EdgeResidual(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = SqueezeExcite(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)

@ -10,11 +10,12 @@ import logging
import math
import re
from copy import deepcopy
from functools import partial
import torch.nn as nn
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@ -50,10 +51,7 @@ def resolve_bn_args(kwargs):
def resolve_act_layer(kwargs, default='relu'):
act_layer = kwargs.pop('act_layer', default)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
return act_layer
return get_act_layer(kwargs.pop('act_layer', default))
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):
@ -123,7 +121,9 @@ def _decode_block_str(block_str):
elif v == 'hs':
value = get_act_layer('hard_swish')
elif v == 'sw':
value = get_act_layer('swish')
value = get_act_layer('swish') # aka SiLU
elif v == 'mi':
value = get_act_layer('mish')
else:
continue
options[key] = value
@ -237,7 +237,11 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
arch_args = []
for stack_idx, block_strings in enumerate(arch_def):
if isinstance(depth_multiplier, tuple):
assert len(depth_multiplier) == len(arch_def)
else:
depth_multiplier = (depth_multiplier,) * len(arch_def)
for stack_idx, (block_strings, multiplier) in enumerate(zip(arch_def, depth_multiplier)):
assert isinstance(block_strings, list)
stack_args = []
repeats = []
@ -251,7 +255,7 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc))
else:
arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
arch_args.append(_scale_stage_depth(stack_args, repeats, multiplier, depth_trunc))
return arch_args
@ -264,14 +268,20 @@ class EfficientNetBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
self.output_stride = output_stride
self.pad_type = pad_type
self.round_chs_fn = round_chs_fn
self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
self.act_layer = act_layer
self.norm_layer = norm_layer
self.se_layer = se_layer
self.se_layer = get_attn(se_layer)
try:
self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
self.se_has_ratio = True
except TypeError:
self.se_has_ratio = False
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@ -298,16 +308,21 @@ class EfficientNetBuilder:
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
ba['norm_layer'] = self.norm_layer
ba['drop_path_rate'] = drop_path_rate
if bt != 'cn':
ba['se_layer'] = self.se_layer
ba['drop_path_rate'] = drop_path_rate
se_ratio = ba.pop('se_ratio')
if se_ratio and self.se_layer is not None:
if not self.se_from_exp:
# adjust se_ratio by expansion ratio if calculating se channels from block input
se_ratio /= ba.get('exp_ratio', 1.0)
if self.se_has_ratio:
ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
else:
ba['se_layer'] = self.se_layer
if bt == 'ir':
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
@ -417,28 +432,28 @@ def _init_weight_goog(m, n='', fix_group_fanout=True):
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
nn.init.uniform_(m.weight, -init_range, init_range)
nn.init.zeros_(m.bias)
def efficientnet_init_weights(model: nn.Module, init_fn=None):

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
from .layers import SelectAdaptivePool2d, Linear, make_divisible
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
from .helpers import build_model_with_cfg
from .registry import register_model
@ -40,7 +40,7 @@ default_cfgs = {
}
_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)
_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
class GhostModule(nn.Module):
@ -92,7 +92,7 @@ class GhostBottleneck(nn.Module):
self.bn_dw = None
# Squeeze-and-excitation
self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None
self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
# Point-wise linear projection
self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

@ -4,7 +4,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
"""
num_features = 1280
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,

@ -378,7 +378,11 @@ def update_default_cfg_and_kwargs(default_cfg, kwargs, kwargs_filter):
# Overlay default cfg values from `external_default_cfg` if it exists in kwargs
overlay_external_default_cfg(default_cfg, kwargs)
# Set model __init__ args that can be determined by default_cfg (if not already passed as kwargs)
set_default_kwargs(kwargs, names=('num_classes', 'global_pool', 'in_chans'), default_cfg=default_cfg)
default_kwarg_names = ('num_classes', 'global_pool', 'in_chans')
if default_cfg.get('fixed_input_size', False):
# if fixed_input_size exists and is True, model takes an img_size arg that fixes its input size
default_kwarg_names += ('img_size',)
set_default_kwargs(kwargs, names=default_kwarg_names, default_cfg=default_cfg)
# Filter keyword args for task specific model variants (some 'features only' models, etc.)
filter_kwargs(kwargs, names=kwargs_filter)

@ -12,25 +12,28 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .create_self_attn import get_self_attn, create_self_attn
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .involution import Involution
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp
from .norm import GroupNorm
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool

@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
import torch.nn.functional as F
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
class ChannelAttn(nn.Module):
""" Original CBAM channel attention module, currently avg + max pool variant only.
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(ChannelAttn, self).__init__()
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = x.mean((2, 3), keepdim=True)
x_max = F.adaptive_max_pool2d(x, 1)
x_avg = self.fc2(self.act(self.fc1(x_avg)))
x_max = self.fc2(self.act(self.fc1(x_max)))
x_attn = x_avg + x_max
return x * x_attn.sigmoid()
x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
return x * self.gate(x_avg + x_max)
class LightChannelAttn(ChannelAttn):
"""An experimental 'lightweight' that sums avg + max pool first
"""
def __init__(self, channels, reduction=16):
super(LightChannelAttn, self).__init__(channels, reduction)
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(LightChannelAttn, self).__init__(
channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
def forward(self, x):
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
x_attn = self.fc2(self.act(self.fc1(x_pool)))
return x * x_attn.sigmoid()
return x * F.sigmoid(x_attn)
class SpatialAttn(nn.Module):
""" Original CBAM spatial attention module
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(SpatialAttn, self).__init__()
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = torch.cat([x_avg, x_max], dim=1)
x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class LightSpatialAttn(nn.Module):
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(LightSpatialAttn, self).__init__()
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = 0.5 * x_avg + 0.5 * x_max
x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class CbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(CbamModule, self).__init__()
self.channel = ChannelAttn(channels)
self.spatial = SpatialAttn(spatial_kernel_size)
self.channel = ChannelAttn(
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
def forward(self, x):
x = self.channel(x)
@ -87,9 +96,13 @@ class CbamModule(nn.Module):
class LightCbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(LightCbamModule, self).__init__()
self.channel = LightChannelAttn(channels)
self.channel = LightChannelAttn(
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
self.spatial = LightSpatialAttn(spatial_kernel_size)
def forward(self, x):

@ -1,20 +1,26 @@
""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Union, Callable, Type
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
# will use native version if present. Eventually, the custom Swish layers will be removed
# and only native 'silu' will be used.
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
_has_silu = 'silu' in dir(torch.nn.functional)
_has_hardswish = 'hardswish' in dir(torch.nn.functional)
_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
_has_mish = 'mish' in dir(torch.nn.functional)
_ACT_FN_DEFAULT = dict(
silu=F.silu if _has_silu else swish,
swish=F.silu if _has_silu else swish,
mish=mish,
mish=F.mish if _has_mish else mish,
relu=F.relu,
relu6=F.relu6,
leaky_relu=F.leaky_relu,
@ -24,33 +30,39 @@ _ACT_FN_DEFAULT = dict(
gelu=gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
hard_swish=hard_swish,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
hard_swish=F.hardswish if _has_hardswish else hard_swish,
hard_mish=hard_mish,
)
_ACT_FN_JIT = dict(
silu=F.silu if _has_silu else swish_jit,
swish=F.silu if _has_silu else swish_jit,
mish=mish_jit,
hard_sigmoid=hard_sigmoid_jit,
hard_swish=hard_swish_jit,
mish=F.mish if _has_mish else mish_jit,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
hard_mish=hard_mish_jit
)
_ACT_FN_ME = dict(
silu=F.silu if _has_silu else swish_me,
swish=F.silu if _has_silu else swish_me,
mish=mish_me,
hard_sigmoid=hard_sigmoid_me,
hard_swish=hard_swish_me,
mish=F.mish if _has_mish else mish_me,
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
hard_mish=hard_mish_me,
)
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
_ACT_LAYER_DEFAULT = dict(
silu=nn.SiLU if _has_silu else Swish,
swish=nn.SiLU if _has_silu else Swish,
mish=Mish,
mish=nn.Mish if _has_mish else Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
leaky_relu=nn.LeakyReLU,
@ -61,37 +73,44 @@ _ACT_LAYER_DEFAULT = dict(
gelu=GELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,
hard_swish=HardSwish,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
hard_mish=HardMish,
)
_ACT_LAYER_JIT = dict(
silu=nn.SiLU if _has_silu else SwishJit,
swish=nn.SiLU if _has_silu else SwishJit,
mish=MishJit,
hard_sigmoid=HardSigmoidJit,
hard_swish=HardSwishJit,
mish=nn.Mish if _has_mish else MishJit,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
hard_mish=HardMishJit
)
_ACT_LAYER_ME = dict(
silu=nn.SiLU if _has_silu else SwishMe,
swish=nn.SiLU if _has_silu else SwishMe,
mish=MishMe,
hard_sigmoid=HardSigmoidMe,
hard_swish=HardSwishMe,
mish=nn.Mish if _has_mish else MishMe,
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
hard_mish=HardMishMe,
)
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
a.setdefault('hardswish', a.get('hard_swish'))
def get_act_fn(name='relu'):
def get_act_fn(name: Union[Callable, str] = 'relu'):
""" Activation Function Factory
Fetching activation fns by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if isinstance(name, Callable):
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
# If not exporting or scripting the model, first look for a memory-efficient version with
# custom autograd, then fallback
@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
""" Activation Layer Factory
Fetching activation layers by name with this function allows export or torch script friendly
functions to be returned dynamically based on current config.
"""
if not name:
return None
if isinstance(name, type):
return name
if not (is_no_jit() or is_exportable() or is_scriptable()):
if name in _ACT_LAYER_ME:
return _ACT_LAYER_ME[name]
@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
act_layer = get_act_layer(name)
if act_layer is not None:
return act_layer(inplace=inplace, **kwargs)
else:
if act_layer is None:
return None
return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)

@ -1,11 +1,23 @@
""" Select AttentionFactory Method
""" Attention Factory
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from functools import partial
from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule
from .swin_attn import WindowAttention
def get_attn(attn_type):
@ -15,18 +27,54 @@ def get_attn(attn_type):
if attn_type is not None:
if isinstance(attn_type, str):
attn_type = attn_type.lower()
# Lightweight attention modules (channel and/or coarse spatial).
# Typically added to existing network architecture blocks in addition to existing convolutions.
if attn_type == 'se':
module_cls = SEModule
elif attn_type == 'ese':
module_cls = EffectiveSEModule
elif attn_type == 'eca':
module_cls = EcaModule
elif attn_type == 'ecam':
module_cls = partial(EcaModule, use_mlp=True)
elif attn_type == 'ceca':
module_cls = CecaModule
elif attn_type == 'ge':
module_cls = GatherExcite
elif attn_type == 'gc':
module_cls = GlobalContext
elif attn_type == 'cbam':
module_cls = CbamModule
elif attn_type == 'lcbam':
module_cls = LightCbamModule
# Attention / attention-like modules w/ significant params
# Typically replace some of the existing workhorse convs in a network architecture.
# All of these accept a stride argument and can spatially downsample the input.
elif attn_type == 'sk':
module_cls = SelectiveKernel
elif attn_type == 'splat':
module_cls = SplitAttn
# Self-attention / attention-like modules w/ significant compute and/or params
# Typically replace some of the existing workhorse convs in a network architecture.
# All of these accept a stride argument and can spatially downsample the input.
elif attn_type == 'lambda':
return LambdaLayer
elif attn_type == 'bottleneck':
return BottleneckAttn
elif attn_type == 'halo':
return HaloAttn
elif attn_type == 'swin':
return WindowAttention
elif attn_type == 'involution':
return Involution
elif attn_type == 'nl':
module_cls = NonLocalAttn
elif attn_type == 'bat':
module_cls = BatNonLocalAttn
# Woops!
else:
assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool):

@ -1,22 +0,0 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention
def get_self_attn(attn_type):
if attn_type == 'bottleneck':
return BottleneckAttn
elif attn_type == 'halo':
return HaloAttn
elif attn_type == 'lambda':
return LambdaLayer
elif attn_type == 'swin':
return WindowAttention
else:
assert False, f"Unknown attn type ({attn_type})"
def create_self_attn(attn_type, dim, stride=1, **kwargs):
attn_fn = get_self_attn(attn_type)
return attn_fn(dim, stride=stride, **kwargs)

@ -38,6 +38,10 @@ from torch import nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class EcaModule(nn.Module):
"""Constructs an ECA module.
@ -48,23 +52,48 @@ class EcaModule(nn.Module):
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
kernel_size: Adaptive selection of kernel size (default=3)
gamm: used in kernel_size calc, see above
beta: used in kernel_size calc, see above
act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
gate_layer: gating non-linearity to use
"""
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
def __init__(
self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
super(EcaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
assert kernel_size % 2 == 1
padding = (kernel_size - 1) // 2
if use_mlp:
# NOTE 'mlp' mode is a timm experiment, not in paper
assert channels is not None
if rd_channels is None:
rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
act_layer = act_layer or nn.ReLU
self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
self.act = create_act_layer(act_layer)
self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
else:
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
self.act = None
self.conv2 = None
self.gate = create_act_layer(gate_layer)
def forward(self, x):
y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
y = self.conv(y)
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
if self.conv2 is not None:
y = self.act(y)
y = self.conv2(y)
y = self.gate(y).view(x.shape[0], -1, 1, 1)
return x * y.expand_as(x)
EfficientChannelAttn = EcaModule # alias
class CecaModule(nn.Module):
"""Constructs a circular ECA module.
@ -83,25 +112,34 @@ class CecaModule(nn.Module):
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
kernel_size: Adaptive selection of kernel size (default=3)
gamm: used in kernel_size calc, see above
beta: used in kernel_size calc, see above
act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
gate_layer: gating non-linearity to use
"""
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
super(CecaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
has_act = act_layer is not None
assert kernel_size % 2 == 1
# PyTorch circular padding mode is buggy as of pytorch 1.4
# see https://github.com/pytorch/pytorch/pull/17240
# implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
self.padding = (kernel_size - 1) // 2
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
y = x.mean((2, 3)).view(x.shape[0], 1, -1)
# Manually implement circular padding, F.pad does not seemed to be bugged
y = F.pad(y, (self.padding, self.padding), mode='circular')
y = self.conv(y)
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
y = self.gate(y).view(x.shape[0], -1, 1, 1)
return x * y.expand_as(x)
CircularEfficientChannelAttn = CecaModule

@ -0,0 +1,90 @@
""" Gather-Excite Attention Block
Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another
impl that covers all of the cases.
NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlp
class GatherExcite(nn.Module):
""" Gather-Excite Attention Module
"""
def __init__(
self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
super(GatherExcite, self).__init__()
self.add_maxpool = add_maxpool
act_layer = get_act_layer(act_layer)
self.extent = extent
if extra_params:
self.gather = nn.Sequential()
if extent == 0:
assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
self.gather.add_module(
'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
if norm_layer:
self.gather.add_module(f'norm1', nn.BatchNorm2d(channels))
else:
assert extent % 2 == 0
num_conv = int(math.log2(extent))
for i in range(num_conv):
self.gather.add_module(
f'conv{i + 1}',
create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
if norm_layer:
self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels))
if i != num_conv - 1:
self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
else:
self.gather = None
if self.extent == 0:
self.gk = 0
self.gs = 0
else:
assert extent % 2 == 0
self.gk = self.extent * 2 - 1
self.gs = self.extent
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
self.gate = create_act_layer(gate_layer)
def forward(self, x):
size = x.shape[-2:]
if self.gather is not None:
x_ge = self.gather(x)
else:
if self.extent == 0:
# global extent
x_ge = x.mean(dim=(2, 3), keepdims=True)
if self.add_maxpool:
# experimental codepath, may remove or change
x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
else:
x_ge = F.avg_pool2d(
x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
if self.add_maxpool:
# experimental codepath, may remove or change
x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
x_ge = self.mlp(x_ge)
if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
x_ge = F.interpolate(x_ge, size=size)
return x * self.gate(x_ge)

@ -0,0 +1,67 @@
""" Global Context Attention Block
Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
- https://arxiv.org/abs/1904.11492
Official code consulted as reference: https://github.com/xvjiarui/GCNet
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
from .mlp import ConvMlp
from .norm import LayerNorm2d
class GlobalContext(nn.Module):
def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False,
rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
super(GlobalContext, self).__init__()
act_layer = get_act_layer(act_layer)
self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None
if rd_channels is None:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
if fuse_add:
self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
else:
self.mlp_add = None
if fuse_scale:
self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
else:
self.mlp_scale = None
self.gate = create_act_layer(gate_layer)
self.init_last_zero = init_last_zero
self.reset_parameters()
def reset_parameters(self):
if self.conv_attn is not None:
nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
if self.mlp_add is not None:
nn.init.zeros_(self.mlp_add.fc2.weight)
def forward(self, x):
B, C, H, W = x.shape
if self.conv_attn is not None:
attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W)
attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1)
context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
context = context.view(B, C, 1, 1)
else:
context = x.mean(dim=(2, 3), keepdim=True)
if self.mlp_scale is not None:
mlp_x = self.mlp_scale(context)
x = x * self.gate(mlp_x)
if self.mlp_add is not None:
mlp_x = self.mlp_add(context)
x = x + mlp_x
return x

@ -28,4 +28,4 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
# Make sure that round down does not go down by more than 10%.
if new_v < round_limit * v:
new_v += divisor
return new_v
return new_v

@ -0,0 +1,50 @@
""" PyTorch Involution Layer
Official impl: https://github.com/d-li14/involution/blob/main/cls/mmcls/models/utils/involution_naive.py
Paper: `Involution: Inverting the Inherence of Convolution for Visual Recognition` - https://arxiv.org/abs/2103.06255
"""
import torch.nn as nn
from .conv_bn_act import ConvBnAct
from .create_conv2d import create_conv2d
class Involution(nn.Module):
def __init__(
self,
channels,
kernel_size=3,
stride=1,
group_size=16,
rd_ratio=4,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
super(Involution, self).__init__()
self.kernel_size = kernel_size
self.stride = stride
self.channels = channels
self.group_size = group_size
self.groups = self.channels // self.group_size
self.conv1 = ConvBnAct(
in_channels=channels,
out_channels=channels // rd_ratio,
kernel_size=1,
norm_layer=norm_layer,
act_layer=act_layer)
self.conv2 = self.conv = create_conv2d(
in_channels=channels // rd_ratio,
out_channels=kernel_size**2 * self.groups,
kernel_size=1,
stride=1)
self.avgpool = nn.AvgPool2d(stride, stride) if stride == 2 else nn.Identity()
self.unfold = nn.Unfold(kernel_size, 1, (kernel_size-1)//2, stride)
def forward(self, x):
weight = self.conv2(self.conv1(self.avgpool(x)))
B, C, H, W = weight.shape
KK = int(self.kernel_size ** 2)
weight = weight.view(B, self.groups, KK, H, W).unsqueeze(2)
out = self.unfold(x).view(B, self.groups, self.group_size, KK, H, W)
out = (weight * out).sum(dim=3).view(B, self.channels, H, W)
return out

@ -77,3 +77,26 @@ class GatedMlp(nn.Module):
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMlp(nn.Module):
""" MLP using 1x1 convs that keeps spatial dims
"""
def __init__(
self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.norm(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
return x

@ -0,0 +1,145 @@
""" Bilinear-Attention-Transform and Non-Local Attention
Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
- https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
import torch
from torch import nn
from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
class NonLocalAttn(nn.Module):
"""Spatial NL block for image classification.
This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.
"""
def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
super(NonLocalAttn, self).__init__()
if rd_channels is None:
rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
self.scale = in_channels ** -0.5 if use_scale else 1.0
self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
self.norm = nn.BatchNorm2d(in_channels)
self.reset_parameters()
def forward(self, x):
shortcut = x
t = self.t(x)
p = self.p(x)
g = self.g(x)
B, C, H, W = t.size()
t = t.view(B, C, -1).permute(0, 2, 1)
p = p.view(B, C, -1)
g = g.view(B, C, -1).permute(0, 2, 1)
att = torch.bmm(t, p) * self.scale
att = F.softmax(att, dim=2)
x = torch.bmm(att, g)
x = x.permute(0, 2, 1).reshape(B, C, H, W)
x = self.z(x)
x = self.norm(x) + shortcut
return x
def reset_parameters(self):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
if len(list(m.parameters())) > 1:
nn.init.constant_(m.bias, 0.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 0)
nn.init.constant_(m.bias, 0)
class BilinearAttnTransform(nn.Module):
def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(BilinearAttnTransform, self).__init__()
self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.block_size = block_size
self.groups = groups
self.in_channels = in_channels
def resize_mat(self, x, t):
B, C, block_size, block_size1 = x.shape
assert block_size == block_size1
if t <= 1:
return x
x = x.view(B * C, -1, 1, 1)
x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
x = x.view(B * C, block_size, block_size, t, t)
x = torch.cat(torch.split(x, 1, dim=1), dim=3)
x = torch.cat(torch.split(x, 1, dim=2), dim=4)
x = x.view(B, C, block_size * t, block_size * t)
return x
def forward(self, x):
assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
B, C, H, W = x.shape
out = self.conv1(x)
rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
cp = F.adaptive_max_pool2d(out, (1, self.block_size))
p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
p = F.sigmoid(p)
q = F.sigmoid(q)
p = p / p.sum(dim=3, keepdim=True)
q = q / q.sum(dim=2, keepdim=True)
p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
p = p.view(B, C, self.block_size, self.block_size)
q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
q = q.view(B, C, self.block_size, self.block_size)
p = self.resize_mat(p, H // self.block_size)
q = self.resize_mat(q, W // self.block_size)
y = p.matmul(x)
y = y.matmul(q)
y = self.conv2(y)
return y
class BatNonLocalAttn(nn.Module):
""" BAT
Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
"""
def __init__(
self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
super().__init__()
if rd_channels is None:
rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.dropout = nn.Dropout2d(p=drop_rate)
def forward(self, x):
xl = self.conv1(x)
y = self.ba(xl)
y = self.conv2(y)
y = self.dropout(y)
return y + x

@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm):
def forward(self, x):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm2d(nn.LayerNorm):
""" Layernorm for channels of '2d' spatial BCHW tensors """
def __init__(self, num_channels):
super().__init__([num_channels, 1, 1])
def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

@ -15,7 +15,7 @@ from .helpers import to_2tuple
class PatchEmbed(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None):
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
@ -23,6 +23,7 @@ class PatchEmbed(nn.Module):
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
@ -31,6 +32,8 @@ class PatchEmbed(nn.Module):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
return x

@ -1,50 +0,0 @@
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* min_channels can be specified to keep reduced channel count at a minimum (default: 8)
* divisor can be specified to keep channels rounded to specified values (default: 1)
* reduction channels can be specified directly by arg (if reduction_channels is set)
* reduction channels can be specified by float ratio (if reduction_ratio is set)
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
super(SEModule, self).__init__()
if reduction_channels is not None:
reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done
elif reduction_ratio is not None:
reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
else:
reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * self.gate(x_se)
class EffectiveSEModule(nn.Module):
""" 'Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer, inplace=True)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate(x_se)

@ -8,6 +8,7 @@ import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
def _kernel_valid(k):
@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module):
return x
class SelectiveKernelConv(nn.Module):
class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
""" Selective Kernel Convolution Module
@ -66,8 +67,7 @@ class SelectiveKernelConv(nn.Module):
stride (int): stride for convolutions
dilation (int): dilation for module as a whole, impacts dilation of each branch
groups (int): number of groups for each branch
attn_reduction (int, float): reduction factor for attention features
min_attn_channels (int): minimum attention feature channels
rd_ratio (int, float): reduction factor for attention features
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
@ -75,7 +75,8 @@ class SelectiveKernelConv(nn.Module):
act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use
"""
super(SelectiveKernelConv, self).__init__()
super(SelectiveKernel, self).__init__()
out_channels = out_channels or in_channels
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
@ -101,7 +102,7 @@ class SelectiveKernelConv(nn.Module):
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block

@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import nn
from .helpers import make_divisible
class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module):
return x
class SplitAttnConv2d(nn.Module):
"""Split-Attention Conv2d
class SplitAttn(nn.Module):
"""Split-Attention (aka Splat)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
super(SplitAttnConv2d, self).__init__()
super(SplitAttn, self).__init__()
out_channels = out_channels or in_channels
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
attn_chs = max(in_channels * radix // reduction_factor, 32)
if rd_channels is None:
attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
else:
attn_chs = rd_channels * radix
padding = kernel_size // 2 if padding is None else padding
self.conv = nn.Conv2d(
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
self.act1 = act_layer(inplace=True)
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
self.rsoftmax = RadixSoftmax(radix, groups)
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.fc1.out_channels
def forward(self, x):
x = self.conv(x)
if self.bn0 is not None:
x = self.bn0(x)
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act0(x)
@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module):
x_gap = x.sum(dim=1)
else:
x_gap = x
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
x_gap = x_gap.mean((2, 3), keepdim=True)
x_gap = self.fc1(x_gap)
if self.bn1 is not None:
x_gap = self.bn1(x_gap)
x_gap = self.bn1(x_gap)
x_gap = self.act1(x_gap)
x_attn = self.fc2(x_gap)

@ -0,0 +1,74 @@
""" Squeeze-and-Excitation Channel Attention
An SE implementation originally based on PyTorch SE-Net impl.
Has since evolved with additional functionality / configuration.
Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
Also included is Effective Squeeze-Excitation (ESE).
Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* divisor can be specified to keep channels % div == 0 (default: 8)
* reduction channels can be specified directly by arg (if rd_channels is set)
* reduction channels can be specified by float rd_ratio (default: 1/16)
* global max pooling can be added to the squeeze aggregation
* customizable activation, normalization, and gate layer
"""
def __init__(
self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
super(SEModule, self).__init__()
self.add_maxpool = add_maxpool
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity()
self.act = create_act_layer(act_layer, inplace=True)
self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
if self.add_maxpool:
# experimental codepath, may remove or change
x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.act(self.bn(x_se))
x_se = self.fc2(x_se)
return x * self.gate(x_se)
SqueezeExcite = SEModule # alias
class EffectiveSEModule(nn.Module):
""" 'Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
super(EffectiveSEModule, self).__init__()
self.add_maxpool = add_maxpool
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
if self.add_maxpool:
# experimental codepath, may remove or change
x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate(x_se)
EffectiveSqueezeExcite = EffectiveSEModule # alias

@ -0,0 +1,570 @@
""" LeViT
Paper: `LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference`
- https://arxiv.org/abs/2104.01136
@article{graham2021levit,
title={LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
author={Benjamin Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Herv\'e J\'egou and Matthijs Douze},
journal={arXiv preprint arXiv:22104.01136},
year={2021}
}
Adapted from official impl at https://github.com/facebookresearch/LeViT, original copyright bellow.
This version combines both conv/linear models and fixes torchscript compatibility.
Modifications by/coyright Copyright 2021 Ross Wightman
"""
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
# Modified from
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
# Copyright 2020 Ross Wightman, Apache-2.0 License
import itertools
from copy import deepcopy
from functools import partial
from typing import Dict
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_ntuple, get_act_layer
from .vision_transformer import trunc_normal_
from .registry import register_model
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embed.0.c', 'classifier': ('head.l', 'head_dist.l'),
**kwargs
}
default_cfgs = dict(
levit_128s=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128S-96703c44.pth'
),
levit_128=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-128-b88c2750.pth'
),
levit_192=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-192-92712e41.pth'
),
levit_256=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-256-13b5763e.pth'
),
levit_384=_cfg(
url='https://dl.fbaipublicfiles.com/LeViT/LeViT-384-9bdaf2e2.pth'
),
)
model_cfgs = dict(
levit_128s=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 6, 8), depth=(2, 3, 4)),
levit_128=dict(
embed_dim=(128, 256, 384), key_dim=16, num_heads=(4, 8, 12), depth=(4, 4, 4)),
levit_192=dict(
embed_dim=(192, 288, 384), key_dim=32, num_heads=(3, 5, 6), depth=(4, 4, 4)),
levit_256=dict(
embed_dim=(256, 384, 512), key_dim=32, num_heads=(4, 6, 8), depth=(4, 4, 4)),
levit_384=dict(
embed_dim=(384, 512, 768), key_dim=32, num_heads=(6, 9, 12), depth=(4, 4, 4)),
)
__all__ = ['Levit']
@register_model
def levit_128s(pretrained=False, fuse=False,distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_128(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_192(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_256(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_384(pretrained=False, fuse=False, distillation=True, use_conv=False, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128s(pretrained=False, fuse=False, distillation=True, use_conv=True,**kwargs):
return create_levit(
'levit_128s', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_128(pretrained=False, fuse=False,distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_128', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_192(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_192', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_256(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_256', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
@register_model
def levit_c_384(pretrained=False, fuse=False, distillation=True, use_conv=True, **kwargs):
return create_levit(
'levit_384', pretrained=pretrained, fuse=fuse, distillation=distillation, use_conv=use_conv, **kwargs)
class ConvNorm(nn.Sequential):
def __init__(
self, a, b, ks=1, stride=1, pad=0, dilation=1, groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', nn.Conv2d(a, b, ks, stride, pad, dilation, groups, bias=False))
bn = nn.BatchNorm2d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = nn.Conv2d(
w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class LinearNorm(nn.Sequential):
def __init__(self, a, b, bn_weight_init=1, resolution=-100000):
super().__init__()
self.add_module('c', nn.Linear(a, b, bias=False))
bn = nn.BatchNorm1d(b)
nn.init.constant_(bn.weight, bn_weight_init)
nn.init.constant_(bn.bias, 0)
self.add_module('bn', bn)
@torch.no_grad()
def fuse(self):
l, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[:, None]
b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def forward(self, x):
x = self.c(x)
return self.bn(x.flatten(0, 1)).reshape_as(x)
class NormLinear(nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', nn.BatchNorm1d(a))
l = nn.Linear(a, b, bias=bias)
trunc_normal_(l.weight, std=std)
if bias:
nn.init.constant_(l.bias, 0)
self.add_module('l', l)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps) ** 0.5
b = bn.bias - self.bn.running_mean * self.bn.weight / (bn.running_var + bn.eps) ** 0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = nn.Linear(w.size(1), w.size(0))
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
def stem_b16(in_chs, out_chs, activation, resolution=224):
return nn.Sequential(
ConvNorm(in_chs, out_chs // 8, 3, 2, 1, resolution=resolution),
activation(),
ConvNorm(out_chs // 8, out_chs // 4, 3, 2, 1, resolution=resolution // 2),
activation(),
ConvNorm(out_chs // 4, out_chs // 2, 3, 2, 1, resolution=resolution // 4),
activation(),
ConvNorm(out_chs // 2, out_chs, 3, 2, 1, resolution=resolution // 8))
class Residual(nn.Module):
def __init__(self, m, drop):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(
x.size(0), 1, 1, device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
class Subsample(nn.Module):
def __init__(self, stride, resolution):
super().__init__()
self.stride = stride
self.resolution = resolution
def forward(self, x):
B, N, C = x.shape
x = x.view(B, self.resolution, self.resolution, C)[:, ::self.stride, ::self.stride]
return x.reshape(B, -1, C)
class Attention(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, dim, key_dim, num_heads=8, attn_ratio=4, act_layer=None, resolution=14, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = int(attn_ratio * key_dim) * num_heads
self.attn_ratio = attn_ratio
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
h = self.dh + nh_kd * 2
self.qkv = ln_layer(dim, h, resolution=resolution)
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, dim, bn_weight_init=0, resolution=resolution))
points = list(itertools.product(range(resolution), range(resolution)))
N = len(points)
attention_offsets = {}
idxs = []
for p1 in points:
for p2 in points:
offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1]))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N, N))
self.ab = {}
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x): # x (B,C,H,W)
if self.use_conv:
B, C, H, W = x.shape
q, k, v = self.qkv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.key_dim, self.d], dim=2)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
else:
B, N, C = x.shape
qkv = self.qkv(x)
q, k, v = qkv.view(B, N, self.num_heads, -1).split([self.key_dim, self.key_dim, self.d], dim=3)
q = q.permute(0, 2, 1, 3)
k = k.permute(0, 2, 1, 3)
v = v.permute(0, 2, 1, 3)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh)
x = self.proj(x)
return x
class AttentionSubsample(nn.Module):
ab: Dict[str, torch.Tensor]
def __init__(
self, in_dim, out_dim, key_dim, num_heads=8, attn_ratio=2,
act_layer=None, stride=2, resolution=14, resolution_=7, use_conv=False):
super().__init__()
self.num_heads = num_heads
self.scale = key_dim ** -0.5
self.key_dim = key_dim
self.nh_kd = nh_kd = key_dim * num_heads
self.d = int(attn_ratio * key_dim)
self.dh = self.d * self.num_heads
self.attn_ratio = attn_ratio
self.resolution_ = resolution_
self.resolution_2 = resolution_ ** 2
self.use_conv = use_conv
if self.use_conv:
ln_layer = ConvNorm
sub_layer = partial(nn.AvgPool2d, kernel_size=1, padding=0)
else:
ln_layer = LinearNorm
sub_layer = partial(Subsample, resolution=resolution)
h = self.dh + nh_kd
self.kv = ln_layer(in_dim, h, resolution=resolution)
self.q = nn.Sequential(
sub_layer(stride=stride),
ln_layer(in_dim, nh_kd, resolution=resolution_))
self.proj = nn.Sequential(
act_layer(),
ln_layer(self.dh, out_dim, resolution=resolution_))
self.stride = stride
self.resolution = resolution
points = list(itertools.product(range(resolution), range(resolution)))
points_ = list(itertools.product(range(resolution_), range(resolution_)))
N = len(points)
N_ = len(points_)
attention_offsets = {}
idxs = []
for p1 in points_:
for p2 in points:
size = 1
offset = (
abs(p1[0] * stride - p2[0] + (size - 1) / 2),
abs(p1[1] * stride - p2[1] + (size - 1) / 2))
if offset not in attention_offsets:
attention_offsets[offset] = len(attention_offsets)
idxs.append(attention_offsets[offset])
self.attention_biases = nn.Parameter(torch.zeros(num_heads, len(attention_offsets)))
self.register_buffer('attention_bias_idxs', torch.LongTensor(idxs).view(N_, N))
self.ab = {} # per-device attention_biases cache
@torch.no_grad()
def train(self, mode=True):
super().train(mode)
if mode and self.ab:
self.ab = {} # clear ab cache
def get_attention_biases(self, device: torch.device) -> torch.Tensor:
if self.training:
return self.attention_biases[:, self.attention_bias_idxs]
else:
device_key = str(device)
if device_key not in self.ab:
self.ab[device_key] = self.attention_biases[:, self.attention_bias_idxs]
return self.ab[device_key]
def forward(self, x):
if self.use_conv:
B, C, H, W = x.shape
k, v = self.kv(x).view(B, self.num_heads, -1, H * W).split([self.key_dim, self.d], dim=2)
q = self.q(x).view(B, self.num_heads, self.key_dim, self.resolution_2)
attn = (q.transpose(-2, -1) @ k) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).reshape(B, -1, self.resolution_, self.resolution_)
else:
B, N, C = x.shape
k, v = self.kv(x).view(B, N, self.num_heads, -1).split([self.key_dim, self.d], dim=3)
k = k.permute(0, 2, 1, 3) # BHNC
v = v.permute(0, 2, 1, 3) # BHNC
q = self.q(x).view(B, self.resolution_2, self.num_heads, self.key_dim).permute(0, 2, 1, 3)
attn = q @ k.transpose(-2, -1) * self.scale + self.get_attention_biases(x.device)
attn = attn.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, -1, self.dh)
x = self.proj(x)
return x
class Levit(nn.Module):
""" Vision Transformer with support for patch or hybrid CNN input stage
"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=(192,),
key_dim=64,
depth=(12,),
num_heads=(3,),
attn_ratio=2,
mlp_ratio=2,
hybrid_backbone=None,
down_ops=None,
act_layer='hard_swish',
attn_act_layer='hard_swish',
distillation=True,
use_conv=False,
drop_path=0):
super().__init__()
act_layer = get_act_layer(act_layer)
attn_act_layer = get_act_layer(attn_act_layer)
if isinstance(img_size, tuple):
# FIXME origin impl passes single img/res dim through whole hierarchy,
# not sure this model will be used enough to spend time fixing it.
assert img_size[0] == img_size[1]
img_size = img_size[0]
self.num_classes = num_classes
self.num_features = embed_dim[-1]
self.embed_dim = embed_dim
N = len(embed_dim)
assert len(depth) == len(num_heads) == N
key_dim = to_ntuple(N)(key_dim)
attn_ratio = to_ntuple(N)(attn_ratio)
mlp_ratio = to_ntuple(N)(mlp_ratio)
down_ops = down_ops or (
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
('Subsample', key_dim[0], embed_dim[0] // key_dim[0], 4, 2, 2),
('Subsample', key_dim[0], embed_dim[1] // key_dim[1], 4, 2, 2),
('',)
)
self.distillation = distillation
self.use_conv = use_conv
ln_layer = ConvNorm if self.use_conv else LinearNorm
self.patch_embed = hybrid_backbone or stem_b16(in_chans, embed_dim[0], activation=act_layer)
self.blocks = []
resolution = img_size // patch_size
for i, (ed, kd, dpth, nh, ar, mr, do) in enumerate(
zip(embed_dim, key_dim, depth, num_heads, attn_ratio, mlp_ratio, down_ops)):
for _ in range(dpth):
self.blocks.append(
Residual(
Attention(
ed, kd, nh, attn_ratio=ar, act_layer=attn_act_layer,
resolution=resolution, use_conv=use_conv),
drop_path))
if mr > 0:
h = int(ed * mr)
self.blocks.append(
Residual(nn.Sequential(
ln_layer(ed, h, resolution=resolution),
act_layer(),
ln_layer(h, ed, bn_weight_init=0, resolution=resolution),
), drop_path))
if do[0] == 'Subsample':
# ('Subsample',key_dim, num_heads, attn_ratio, mlp_ratio, stride)
resolution_ = (resolution - 1) // do[5] + 1
self.blocks.append(
AttentionSubsample(
*embed_dim[i:i + 2], key_dim=do[1], num_heads=do[2],
attn_ratio=do[3], act_layer=attn_act_layer, stride=do[5],
resolution=resolution, resolution_=resolution_, use_conv=use_conv))
resolution = resolution_
if do[4] > 0: # mlp_ratio
h = int(embed_dim[i + 1] * do[4])
self.blocks.append(
Residual(nn.Sequential(
ln_layer(embed_dim[i + 1], h, resolution=resolution),
act_layer(),
ln_layer(h, embed_dim[i + 1], bn_weight_init=0, resolution=resolution),
), drop_path))
self.blocks = nn.Sequential(*self.blocks)
# Classifier head
self.head = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
if distillation:
self.head_dist = NormLinear(embed_dim[-1], num_classes) if num_classes > 0 else nn.Identity()
else:
self.head_dist = None
@torch.jit.ignore
def no_weight_decay(self):
return {x for x in self.state_dict().keys() if 'attention_biases' in x}
def forward(self, x):
x = self.patch_embed(x)
if not self.use_conv:
x = x.flatten(2).transpose(1, 2)
x = self.blocks(x)
x = x.mean((-2, -1)) if self.use_conv else x.mean(1)
if self.head_dist is not None:
x, x_dist = self.head(x), self.head_dist(x)
if self.training and not torch.jit.is_scripting():
return x, x_dist
else:
# during inference, return the average of both classifier predictions
return (x + x_dist) / 2
else:
x = self.head(x)
return x
def checkpoint_filter_fn(state_dict, model):
if 'model' in state_dict:
# For deit models
state_dict = state_dict['model']
D = model.state_dict()
for k in state_dict.keys():
if D[k].ndim == 4 and state_dict[k].ndim == 2:
state_dict[k] = state_dict[k][:, :, None, None]
return state_dict
def create_levit(variant, pretrained=False, default_cfg=None, fuse=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model_cfg = dict(**model_cfgs[variant], **kwargs)
model = build_model_with_cfg(
Levit, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**model_cfg)
#if fuse:
# utils.replace_batchnorm(model)
return model

@ -80,6 +80,15 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_mixer_l16_224_in21k-846aa33c.pth',
num_classes=21843
),
# Mixer ImageNet-21K-P pretraining
mixer_b16_224_miil_in21k=_cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil_in21k.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear', num_classes=11221,
),
mixer_b16_224_miil=_cfg(
url='https://miil-public-eu.oss-eu-central-1.aliyuncs.com/model-zoo/ImageNet_21K_P/models/timm/mixer_b16_224_miil.pth',
mean=(0, 0, 0), std=(1, 1, 1), crop_pct=0.875, interpolation='bilinear',
),
gmixer_12_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
gmixer_24_224=_cfg(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD),
@ -264,25 +273,14 @@ def _init_weights(m, n: str, head_bias: float = 0.):
nn.init.ones_(m.weight)
def _create_mixer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
def _create_mixer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for MLP-Mixer models.')
model = build_model_with_cfg(
MlpMixer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@ -365,6 +363,23 @@ def mixer_l16_224_in21k(pretrained=False, **kwargs):
model = _create_mixer('mixer_l16_224_in21k', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_b16_224_miil(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-21k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil', pretrained=pretrained, **model_args)
return model
@register_model
def mixer_b16_224_miil_in21k(pretrained=False, **kwargs):
""" Mixer-B/16 224x224. ImageNet-1k pretrained weights.
Weights taken from: https://github.com/Alibaba-MIIL/ImageNet21K
"""
model_args = dict(patch_size=16, num_blocks=12, hidden_dim=768, **kwargs)
model = _create_mixer('mixer_b16_224_miil_in21k', pretrained=pretrained, **model_args)
return model
@register_model
def gmixer_12_224(pretrained=False, **kwargs):

@ -72,6 +72,10 @@ default_cfgs = {
'tf_mobilenetv3_small_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b': _cfg(),
'fbnetv3_d': _cfg(),
'fbnetv3_g': _cfg(),
}
@ -86,7 +90,7 @@ class MobileNetV3(nn.Module):
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
pad_type='', act_layer=None, norm_layer=None, se_layer=None,
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
@ -104,7 +108,7 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
@ -178,7 +182,7 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
drop_path_rate=drop_path_rate, feature_location=feature_location)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
@ -350,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
@ -366,6 +369,67 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
return model
def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" FBNetV3
Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
- https://arxiv.org/abs/2006.02049
FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
"""
vl = variant.split('_')[-1]
if vl in ('a', 'b'):
stem_size = 16
arch_def = [
['ds_r2_k3_s1_e1_c16'],
['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
['cn_r1_k1_s1_c1344'],
]
elif vl == 'd':
stem_size = 24
arch_def = [
['ds_r2_k3_s1_e1_c16'],
['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
['cn_r1_k1_s1_c1440'],
]
elif vl == 'g':
stem_size = 32
arch_def = [
['ds_r3_k3_s1_e1_c24'],
['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
['cn_r1_k1_s1_c1728'],
]
else:
raise NotImplemented
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
act_layer = resolve_act_layer(kwargs, 'hard_swish')
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=1984,
head_bias=False,
stem_size=stem_size,
round_chs_fn=round_chs_fn,
se_from_exp=False,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
se_layer=se_layer,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model
@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
@ -474,3 +538,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_b(pretrained=False, **kwargs):
""" FBNetV3-B """
model = _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_d(pretrained=False, **kwargs):
""" FBNetV3-D """
model = _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_g(pretrained=False, **kwargs):
""" FBNetV3-G """
model = _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)
return model

@ -110,6 +110,12 @@ default_cfgs = dict(
eca_nfnet_l1=_dcfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ecanfnet_l1_ra2-7dce93cd.pth',
pool_size=(8, 8), input_size=(3, 256, 256), test_input_size=(3, 320, 320), crop_pct=1.0),
eca_nfnet_l2=_dcfg(
url='',
pool_size=(9, 9), input_size=(3, 288, 288), test_input_size=(3, 352, 352), crop_pct=1.0),
eca_nfnet_l3=_dcfg(
url='',
pool_size=(10, 10), input_size=(3, 320, 320), test_input_size=(3, 384, 384), crop_pct=1.0),
nf_regnet_b0=_dcfg(
url='', pool_size=(6, 6), input_size=(3, 192, 192), test_input_size=(3, 256, 256), first_conv='stem.conv'),
@ -176,7 +182,7 @@ def _nfres_cfg(
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
@ -187,7 +193,7 @@ def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
@ -196,11 +202,10 @@ def _nfnet_cfg(
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
return cfg
@ -237,13 +242,19 @@ model_cfgs = dict(
# Experimental 'light' versions of NFNet-F that are little leaner
nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
eca_nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
eca_nfnet_l1=_nfnet_cfg(
depths=(2, 4, 12, 6), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
eca_nfnet_l2=_nfnet_cfg(
depths=(3, 6, 18, 9), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
eca_nfnet_l3=_nfnet_cfg(
depths=(4, 8, 24, 12), feat_mult=2, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
# EffNet influenced RegNet defs.
# NOTE: These aren't quite the official ver, ch_div=1 must be set for exact ch counts. I round to ch_div=8.
@ -260,9 +271,9 @@ model_cfgs = dict(
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),
@ -814,6 +825,22 @@ def eca_nfnet_l1(pretrained=False, **kwargs):
return _create_normfreenet('eca_nfnet_l1', pretrained=pretrained, **kwargs)
@register_model
def eca_nfnet_l2(pretrained=False, **kwargs):
""" ECA-NFNet-L2 w/ SiLU
My experimental 'light' model w/ F2 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('eca_nfnet_l2', pretrained=pretrained, **kwargs)
@register_model
def eca_nfnet_l3(pretrained=False, **kwargs):
""" ECA-NFNet-L3 w/ SiLU
My experimental 'light' model w/ F3 repeats, 2.0x final_conv mult, 64 group_size, .25 bottleneck & ECA attn
"""
return _create_normfreenet('eca_nfnet_l3', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
""" Normalization-Free RegNet-B0

@ -251,24 +251,14 @@ def checkpoint_filter_fn(state_dict, model):
def _create_pit(variant, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
img_size = kwargs.pop('img_size', default_img_size)
num_classes = kwargs.pop('num_classes', default_num_classes)
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
PoolingVisionTransformer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -50,7 +50,7 @@ def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
def list_models(filter='', module='', pretrained=False, exclude_filters=''):
def list_models(filter='', module='', pretrained=False, exclude_filters='', name_matches_cfg=False):
""" Return list of available model names, sorted alphabetically
Args:
@ -58,6 +58,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
module (str) - Limit model selection to a specific sub-module (ie 'gen_efficientnet')
pretrained (bool) - Include only models with pretrained weights if True
exclude_filters (str or list[str]) - Wildcard filters to exclude models after including them with filter
name_matches_cfg (bool) - Include only models w/ model_name matching default_cfg name (excludes some aliases)
Example:
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
@ -70,7 +71,7 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
if filter:
models = fnmatch.filter(models, filter) # include these models
if exclude_filters:
if not isinstance(exclude_filters, list):
if not isinstance(exclude_filters, (tuple, list)):
exclude_filters = [exclude_filters]
for xf in exclude_filters:
exclude_models = fnmatch.filter(models, xf) # exclude these models
@ -78,6 +79,8 @@ def list_models(filter='', module='', pretrained=False, exclude_filters=''):
models = set(models).difference(exclude_models)
if pretrained:
models = _model_has_pretrained.intersection(models)
if name_matches_cfg:
models = set(_model_default_cfgs).intersection(models)
return list(sorted(models, key=_natural_key))

@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
groups=groups, **cargs)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
else:
self.se = None
cargs['act_layer'] = None

@ -11,7 +11,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SplitAttnConv2d
from .layers import SplitAttn
from .registry import register_model
from .resnet import ResNet
@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module):
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
if self.radix >= 1:
self.conv2 = SplitAttnConv2d(
self.conv2 = SplitAttn(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
self.act2 = None
self.bn2 = nn.Identity()
self.act2 = nn.Identity()
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module):
out = self.avd_first(out)
out = self.conv2(out)
if self.bn2 is not None:
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act2(out)
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act2(out)
if self.avd_last is not None:
out = self.avd_last(out)

@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)

@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman
"""
import torch.nn as nn
from functools import partial
from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -48,26 +49,7 @@ default_cfgs = dict(
url=''),
)
class SEWithNorm(nn.Module):
def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
gate_layer='sigmoid'):
super(SEWithNorm, self).__init__()
reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.bn = nn.BatchNorm2d(reduction_channels)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.bn(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * self.gate(x_se)
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
class LinearBottleneck(nn.Module):
@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module):
self.conv_exp = None
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else:
self.se = None
self.act_dw = create_act_layer(dw_act_layer)
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)

@ -14,7 +14,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .layers import SelectiveKernel, ConvBnAct, create_attn
from .registry import register_model
from .resnet import ResNet
@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module):
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = SelectiveKernelConv(
self.conv1 = SelectiveKernel(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module):
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv2 = SelectiveKernelConv(
self.conv2 = SelectiveKernel(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
@ -213,8 +207,9 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
zero_init_last_bn=False, **kwargs)
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)

@ -12,9 +12,11 @@ import torch.nn as nn
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.helpers import load_pretrained
from timm.models.helpers import build_model_with_cfg
from timm.models.layers import Mlp, DropPath, trunc_normal_
from timm.models.layers.helpers import to_2tuple
from timm.models.registry import register_model
from timm.models.vision_transformer import resize_pos_embed
def _cfg(url='', **kwargs):
@ -118,11 +120,15 @@ class PixelEmbed(nn.Module):
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, in_dim=48, stride=4):
super().__init__()
num_patches = (img_size // patch_size) ** 2
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
# grid_size property necessary for resizing positional embedding
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
num_patches = (self.grid_size[0]) * (self.grid_size[1])
self.img_size = img_size
self.num_patches = num_patches
self.in_dim = in_dim
new_patch_size = math.ceil(patch_size / stride)
new_patch_size = [math.ceil(ps / stride) for ps in patch_size]
self.new_patch_size = new_patch_size
self.proj = nn.Conv2d(in_chans, self.in_dim, kernel_size=7, padding=3, stride=stride)
@ -130,11 +136,11 @@ class PixelEmbed(nn.Module):
def forward(self, x, pixel_pos):
B, C, H, W = x.shape
assert H == self.img_size and W == self.img_size, \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
x = self.unfold(x)
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size, self.new_patch_size)
x = x.transpose(1, 2).reshape(B * self.num_patches, self.in_dim, self.new_patch_size[0], self.new_patch_size[1])
x = x + pixel_pos
x = x.reshape(B * self.num_patches, self.in_dim, -1).transpose(1, 2)
return x
@ -155,7 +161,7 @@ class TNT(nn.Module):
num_patches = self.pixel_embed.num_patches
self.num_patches = num_patches
new_patch_size = self.pixel_embed.new_patch_size
num_pixel = new_patch_size ** 2
num_pixel = new_patch_size[0] * new_patch_size[1]
self.norm1_proj = norm_layer(num_pixel * in_dim)
self.proj = nn.Linear(num_pixel * in_dim, embed_dim)
@ -163,7 +169,7 @@ class TNT(nn.Module):
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.patch_pos = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size, new_patch_size))
self.pixel_pos = nn.Parameter(torch.zeros(1, in_dim, new_patch_size[0], new_patch_size[1]))
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
@ -224,23 +230,39 @@ class TNT(nn.Module):
return x
def checkpoint_filter_fn(state_dict, model):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if state_dict['patch_pos'].shape != model.patch_pos.shape:
state_dict['patch_pos'] = resize_pos_embed(state_dict['patch_pos'],
model.patch_pos, getattr(model, 'num_tokens', 1), model.pixel_embed.grid_size)
return state_dict
def _create_tnt(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
TNT, variant, pretrained,
default_cfg=default_cfgs[variant],
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model
@register_model
def tnt_s_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
model_cfg = dict(
patch_size=16, embed_dim=384, in_dim=24, depth=12, num_heads=6, in_num_head=4,
qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_s_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_tnt('tnt_s_patch16_224', pretrained=pretrained, **model_cfg)
return model
@register_model
def tnt_b_patch16_224(pretrained=False, **kwargs):
model = TNT(patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
model_cfg = dict(
patch_size=16, embed_dim=640, in_dim=40, depth=12, num_heads=10, in_num_head=4,
qkv_bias=False, **kwargs)
model.default_cfg = default_cfgs['tnt_b_patch16_224']
if pretrained:
load_pretrained(
model, num_classes=model.num_classes, in_chans=kwargs.get('in_chans', 3))
model = _create_tnt('tnt_b_patch16_224', pretrained=pretrained, **model_cfg)
return model

@ -84,8 +84,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
reduction_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
rd_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
def forward(self, x):
if self.downsample is not None:
@ -125,7 +125,7 @@ class Bottleneck(nn.Module):
aa_layer(channels=planes, filt_size=3, stride=2))
reduction_chs = max(planes * self.expansion // 8, 64)
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")

@ -0,0 +1,420 @@
""" Twins
A PyTorch impl of : `Twins: Revisiting the Design of Spatial Attention in Vision Transformers`
- https://arxiv.org/pdf/2104.13840.pdf
Code/weights from https://github.com/Meituan-AutoML/Twins, original copyright/license info below
"""
# --------------------------------------------------------
# Twins
# Copyright (c) 2021 Meituan
# Licensed under The Apache 2.0 License [see LICENSE for details]
# Written by Xinjie Li, Xiangxiang Chu
# --------------------------------------------------------
import math
from copy import deepcopy
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .layers import Mlp, DropPath, to_2tuple, trunc_normal_
from .registry import register_model
from .vision_transformer import Attention
from .helpers import build_model_with_cfg, overlay_external_default_cfg
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'patch_embeds.0.proj', 'classifier': 'head',
**kwargs
}
default_cfgs = {
'twins_pcpvt_small': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_small-e70e7e7a.pth',
),
'twins_pcpvt_base': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_base-e5ecb09b.pth',
),
'twins_pcpvt_large': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_pcpvt_large-d273f802.pth',
),
'twins_svt_small': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_small-42e5f78c.pth',
),
'twins_svt_base': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_base-c2265010.pth',
),
'twins_svt_large': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/twins_svt_large-90f6aaa9.pth',
),
}
Size_ = Tuple[int, int]
class LocallyGroupedAttn(nn.Module):
""" LSA: self attention within a group
"""
def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., ws=1):
assert ws != 1
super(LocallyGroupedAttn, self).__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=True)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.ws = ws
def forward(self, x, size: Size_):
# There are two implementations for this function, zero padding or mask. We don't observe obvious difference for
# both. You can choose any one, we recommend forward_padding because it's neat. However,
# the masking implementation is more reasonable and accurate.
B, N, C = x.shape
H, W = size
x = x.view(B, H, W, C)
pad_l = pad_t = 0
pad_r = (self.ws - W % self.ws) % self.ws
pad_b = (self.ws - H % self.ws) % self.ws
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
_, Hp, Wp, _ = x.shape
_h, _w = Hp // self.ws, Wp // self.ws
x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3)
qkv = self.qkv(x).reshape(
B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
if pad_r > 0 or pad_b > 0:
x = x[:, :H, :W, :].contiguous()
x = x.reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
# def forward_mask(self, x, size: Size_):
# B, N, C = x.shape
# H, W = size
# x = x.view(B, H, W, C)
# pad_l = pad_t = 0
# pad_r = (self.ws - W % self.ws) % self.ws
# pad_b = (self.ws - H % self.ws) % self.ws
# x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
# _, Hp, Wp, _ = x.shape
# _h, _w = Hp // self.ws, Wp // self.ws
# mask = torch.zeros((1, Hp, Wp), device=x.device)
# mask[:, -pad_b:, :].fill_(1)
# mask[:, :, -pad_r:].fill_(1)
#
# x = x.reshape(B, _h, self.ws, _w, self.ws, C).transpose(2, 3) # B, _h, _w, ws, ws, C
# mask = mask.reshape(1, _h, self.ws, _w, self.ws).transpose(2, 3).reshape(1, _h * _w, self.ws * self.ws)
# attn_mask = mask.unsqueeze(2) - mask.unsqueeze(3) # 1, _h*_w, ws*ws, ws*ws
# attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-1000.0)).masked_fill(attn_mask == 0, float(0.0))
# qkv = self.qkv(x).reshape(
# B, _h * _w, self.ws * self.ws, 3, self.num_heads, C // self.num_heads).permute(3, 0, 1, 4, 2, 5)
# # n_h, B, _w*_h, nhead, ws*ws, dim
# q, k, v = qkv[0], qkv[1], qkv[2] # B, _h*_w, n_head, ws*ws, dim_head
# attn = (q @ k.transpose(-2, -1)) * self.scale # B, _h*_w, n_head, ws*ws, ws*ws
# attn = attn + attn_mask.unsqueeze(2)
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn) # attn @v -> B, _h*_w, n_head, ws*ws, dim_head
# attn = (attn @ v).transpose(2, 3).reshape(B, _h, _w, self.ws, self.ws, C)
# x = attn.transpose(2, 3).reshape(B, _h * self.ws, _w * self.ws, C)
# if pad_r > 0 or pad_b > 0:
# x = x[:, :H, :W, :].contiguous()
# x = x.reshape(B, N, C)
# x = self.proj(x)
# x = self.proj_drop(x)
# return x
class GlobalSubSampleAttn(nn.Module):
""" GSA: using a key to summarize the information for a group to be efficient.
"""
def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0., sr_ratio=1):
super().__init__()
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
self.dim = dim
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.q = nn.Linear(dim, dim, bias=True)
self.kv = nn.Linear(dim, dim * 2, bias=True)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
self.sr_ratio = sr_ratio
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
else:
self.sr = None
self.norm = None
def forward(self, x, size: Size_):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if self.sr is not None:
x = x.permute(0, 2, 1).reshape(B, C, *size)
x = self.sr(x).reshape(B, C, -1).permute(0, 2, 1)
x = self.norm(x)
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0.,
act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1, ws=None):
super().__init__()
self.norm1 = norm_layer(dim)
if ws is None:
self.attn = Attention(dim, num_heads, False, None, attn_drop, drop)
elif ws == 1:
self.attn = GlobalSubSampleAttn(dim, num_heads, attn_drop, drop, sr_ratio)
else:
self.attn = LocallyGroupedAttn(dim, num_heads, attn_drop, drop, ws)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
def forward(self, x, size: Size_):
x = x + self.drop_path(self.attn(self.norm1(x), size))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class PosConv(nn.Module):
# PEG from https://arxiv.org/abs/2102.10882
def __init__(self, in_chans, embed_dim=768, stride=1):
super(PosConv, self).__init__()
self.proj = nn.Sequential(nn.Conv2d(in_chans, embed_dim, 3, stride, 1, bias=True, groups=embed_dim), )
self.stride = stride
def forward(self, x, size: Size_):
B, N, C = x.shape
cnn_feat_token = x.transpose(1, 2).view(B, C, *size)
x = self.proj(cnn_feat_token)
if self.stride == 1:
x += cnn_feat_token
x = x.flatten(2).transpose(1, 2)
return x
def no_weight_decay(self):
return ['proj.%d.weight' % i for i in range(4)]
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
assert img_size[0] % patch_size[0] == 0 and img_size[1] % patch_size[1] == 0, \
f"img_size {img_size} should be divided by patch_size {patch_size}."
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x) -> Tuple[torch.Tensor, Size_]:
B, C, H, W = x.shape
x = self.proj(x).flatten(2).transpose(1, 2)
x = self.norm(x)
out_size = (H // self.patch_size[0], W // self.patch_size[1])
return x, out_size
class Twins(nn.Module):
""" Twins Vision Transfomer (Revisiting Spatial Attention)
Adapted from PVT (PyramidVisionTransformer) class at https://github.com/whai362/PVT.git
"""
def __init__(
self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dims=(64, 128, 256, 512),
num_heads=(1, 2, 4, 8), mlp_ratios=(4, 4, 4, 4), drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=(3, 4, 6, 3), sr_ratios=(8, 4, 2, 1), wss=None,
block_cls=Block):
super().__init__()
self.num_classes = num_classes
self.depths = depths
img_size = to_2tuple(img_size)
prev_chs = in_chans
self.patch_embeds = nn.ModuleList()
self.pos_drops = nn.ModuleList()
for i in range(len(depths)):
self.patch_embeds.append(PatchEmbed(img_size, patch_size, prev_chs, embed_dims[i]))
self.pos_drops.append(nn.Dropout(p=drop_rate))
prev_chs = embed_dims[i]
img_size = tuple(t // patch_size for t in img_size)
patch_size = 2
self.blocks = nn.ModuleList()
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
cur = 0
for k in range(len(depths)):
_block = nn.ModuleList([block_cls(
dim=embed_dims[k], num_heads=num_heads[k], mlp_ratio=mlp_ratios[k], drop=drop_rate,
attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer, sr_ratio=sr_ratios[k],
ws=1 if wss is None or i % 2 == 1 else wss[k]) for i in range(depths[k])])
self.blocks.append(_block)
cur += depths[k]
self.pos_block = nn.ModuleList([PosConv(embed_dim, embed_dim) for embed_dim in embed_dims])
self.norm = norm_layer(embed_dims[-1])
# classification head
self.head = nn.Linear(embed_dims[-1], num_classes) if num_classes > 0 else nn.Identity()
# init weights
self.apply(self._init_weights)
@torch.jit.ignore
def no_weight_decay(self):
return set(['pos_block.' + n for n, p in self.pos_block.named_parameters()])
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
def forward_features(self, x):
B = x.shape[0]
for i, (embed, drop, blocks, pos_blk) in enumerate(
zip(self.patch_embeds, self.pos_drops, self.blocks, self.pos_block)):
x, size = embed(x)
x = drop(x)
for j, blk in enumerate(blocks):
x = blk(x, size)
if j == 0:
x = pos_blk(x, size) # PEG here
if i < len(self.depths) - 1:
x = x.reshape(B, *size, -1).permute(0, 3, 1, 2).contiguous()
x = self.norm(x)
return x.mean(dim=1) # GAP here
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_twins(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Twins, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@register_model
def twins_pcpvt_small(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_small', pretrained=pretrained, **model_kwargs)
@register_model
def twins_pcpvt_base(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_base', pretrained=pretrained, **model_kwargs)
@register_model
def twins_pcpvt_large(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4],
depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_pcpvt_large', pretrained=pretrained, **model_kwargs)
@register_model
def twins_svt_small(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[64, 128, 256, 512], num_heads=[2, 4, 8, 16], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 10, 4], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_small', pretrained=pretrained, **model_kwargs)
@register_model
def twins_svt_base(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[96, 192, 384, 768], num_heads=[3, 6, 12, 24], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_base', pretrained=pretrained, **model_kwargs)
@register_model
def twins_svt_large(pretrained=False, **kwargs):
model_kwargs = dict(
patch_size=4, embed_dims=[128, 256, 512, 1024], num_heads=[4, 8, 16, 32], mlp_ratios=[4, 4, 4, 4],
depths=[2, 2, 18, 2], wss=[7, 7, 7, 7], sr_ratios=[8, 4, 2, 1], **kwargs)
return _create_twins('twins_svt_large', pretrained=pretrained, **model_kwargs)

@ -0,0 +1,405 @@
""" Visformer
Paper: Visformer: The Vision-friendly Transformer - https://arxiv.org/abs/2104.12533
From original at https://github.com/danczs/Visformer
"""
from copy import deepcopy
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .registry import register_model
__all__ = ['Visformer']
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head',
**kwargs
}
default_cfgs = dict(
visformer_tiny=_cfg(),
visformer_small=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vt3p-weights/visformer_small-839e1f5b.pth'
),
)
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.in_features = in_features
self.out_features = out_features
self.spatial_conv = spatial_conv
if self.spatial_conv:
if group < 2: # net setting
hidden_features = in_features * 5 // 6
else:
hidden_features = in_features * 2
self.hidden_features = hidden_features
self.group = group
self.drop = nn.Dropout(drop)
self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
self.act1 = act_layer()
if self.spatial_conv:
self.conv2 = nn.Conv2d(
hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
self.act2 = act_layer()
else:
self.conv2 = None
self.act2 = None
self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.act1(x)
x = self.drop(x)
if self.conv2 is not None:
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.drop(x)
return x
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, head_dim_ratio=1., attn_drop=0., proj_drop=0.):
super().__init__()
self.dim = dim
self.num_heads = num_heads
head_dim = round(dim // num_heads * head_dim_ratio)
self.head_dim = head_dim
self.scale = head_dim ** -0.5
self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Conv2d(self.head_dim * self.num_heads, dim, 1, stride=1, padding=0, bias=False)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, C, H, W = x.shape
x = self.qkv(x).reshape(B, 3, self.num_heads, self.head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x[0], x[1], x[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = attn @ v
x = x.permute(0, 1, 3, 2).reshape(B, -1, H, W)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
if attn_disabled:
self.norm1 = None
self.attn = None
else:
self.norm1 = norm_layer(dim)
self.attn = Attention(
dim, num_heads=num_heads, head_dim_ratio=head_dim_ratio, attn_drop=attn_drop, proj_drop=drop)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = SpatialMlp(
in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop,
group=group, spatial_conv=spatial_conv) # new setting
def forward(self, x):
if self.attn is not None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim
self.init_channels = init_channels
self.img_size = img_size
self.vit_stem = vit_stem
self.pool = pool
self.conv_init = conv_init
if isinstance(depth, (list, tuple)):
self.stage_num1, self.stage_num2, self.stage_num3 = depth
depth = sum(depth)
else:
self.stage_num1 = self.stage_num3 = depth // 3
self.stage_num2 = depth - self.stage_num1 - self.stage_num3
self.pos_embed = pos_embed
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# stage 1
if self.vit_stem:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size, in_chans=in_chans,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 16
else:
if self.init_channels is None:
self.stem = None
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 2, in_chans=in_chans,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 8
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, self.init_channels, 7, stride=2, padding=3, bias=False),
nn.BatchNorm2d(self.init_channels),
nn.ReLU(inplace=True)
)
img_size //= 2
self.patch_embed1 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 4, in_chans=self.init_channels,
embed_dim=embed_dim // 2, norm_layer=embed_norm, flatten=False)
img_size //= 4
if self.pos_embed:
if self.vit_stem:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
else:
self.pos_embed1 = nn.Parameter(torch.zeros(1, embed_dim//2, img_size, img_size))
self.pos_drop = nn.Dropout(p=drop_rate)
self.stage1 = nn.ModuleList([
Block(
dim=embed_dim//2, num_heads=num_heads, head_dim_ratio=0.5, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[0] == '0'), spatial_conv=(spatial_conv[0] == '1')
)
for i in range(self.stage_num1)
])
#stage2
if not self.vit_stem:
self.patch_embed2 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim // 2,
embed_dim=embed_dim, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed2 = nn.Parameter(torch.zeros(1, embed_dim, img_size, img_size))
self.stage2 = nn.ModuleList([
Block(
dim=embed_dim, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[1] == '0'), spatial_conv=(spatial_conv[1] == '1')
)
for i in range(self.stage_num1, self.stage_num1+self.stage_num2)
])
# stage 3
if not self.vit_stem:
self.patch_embed3 = PatchEmbed(
img_size=img_size, patch_size=patch_size // 8, in_chans=embed_dim,
embed_dim=embed_dim * 2, norm_layer=embed_norm, flatten=False)
img_size //= 2
if self.pos_embed:
self.pos_embed3 = nn.Parameter(torch.zeros(1, embed_dim*2, img_size, img_size))
self.stage3 = nn.ModuleList([
Block(
dim=embed_dim*2, num_heads=num_heads, head_dim_ratio=1.0, mlp_ratio=mlp_ratio,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
group=group, attn_disabled=(attn_stage[2] == '0'), spatial_conv=(spatial_conv[2] == '1')
)
for i in range(self.stage_num1+self.stage_num2, depth)
])
# head
if self.pool:
self.global_pooling = nn.AdaptiveAvgPool2d(1)
head_dim = embed_dim if self.vit_stem else embed_dim * 2
self.norm = norm_layer(head_dim)
self.head = nn.Linear(head_dim, num_classes)
# weights init
if self.pos_embed:
trunc_normal_(self.pos_embed1, std=0.02)
if not self.vit_stem:
trunc_normal_(self.pos_embed2, std=0.02)
trunc_normal_(self.pos_embed3, std=0.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
if self.conv_init:
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
else:
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0.)
def forward(self, x):
if self.stem is not None:
x = self.stem(x)
# stage 1
x = self.patch_embed1(x)
if self.pos_embed:
x = x + self.pos_embed1
x = self.pos_drop(x)
for b in self.stage1:
x = b(x)
# stage 2
if not self.vit_stem:
x = self.patch_embed2(x)
if self.pos_embed:
x = x + self.pos_embed2
x = self.pos_drop(x)
for b in self.stage2:
x = b(x)
# stage3
if not self.vit_stem:
x = self.patch_embed3(x)
if self.pos_embed:
x = x + self.pos_embed3
x = self.pos_drop(x)
for b in self.stage3:
x = b(x)
# head
x = self.norm(x)
if self.pool:
x = self.global_pooling(x)
else:
x = x[:, :, 0, 0]
x = self.head(x.view(x.size(0), -1))
return x
def _create_visformer(variant, pretrained=False, default_cfg=None, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
Visformer, variant, pretrained,
default_cfg=default_cfgs[variant],
**kwargs)
return model
@register_model
def visformer_tiny(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=16, embed_dim=192, depth=(7, 4, 4), num_heads=3, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_tiny', pretrained=pretrained, **model_cfg)
return model
@register_model
def visformer_small(pretrained=False, **kwargs):
model_cfg = dict(
img_size=224, init_channels=32, embed_dim=384, depth=(7, 4, 4), num_heads=6, mlp_ratio=4., group=8,
attn_stage='011', spatial_conv='100', norm_layer=nn.BatchNorm2d, conv_init=True,
embed_norm=nn.BatchNorm2d, **kwargs)
model = _create_visformer('visformer_small', pretrained=pretrained, **model_cfg)
return model
# @register_model
# def visformer_net1(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=None, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=True, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net2(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(0, 12, 0), num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net3(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net4(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., attn_stage='111',
# spatial_conv='000', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net5(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# spatial_conv='111', vit_stem=False, conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net6(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., group=1, attn_stage='111',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model
#
#
# @register_model
# def visformer_net7(pretrained=False, **kwargs):
# model = Visformer(
# init_channels=32, embed_dim=384, depth=(6, 7, 7), num_heads=6, group=1, attn_stage='000',
# pos_embed=False, spatial_conv='111', conv_init=True, **kwargs)
# model.default_cfg = _cfg()
# return model

@ -352,7 +352,7 @@ def _init_vit_weights(m, n: str = '', head_bias: float = 0., jax_impl: bool = Fa
nn.init.ones_(m.weight)
def resize_pos_embed(posemb, posemb_new, num_tokens=1):
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
@ -363,11 +363,13 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1):
else:
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
gs_old = int(math.sqrt(len(posemb_grid)))
gs_new = int(math.sqrt(ntok_new))
_logger.info('Position embedding grid-size from %s to %s', gs_old, gs_new)
if not len(gs_new): # backwards compatibility
gs_new = [int(math.sqrt(ntok_new))] * 2
assert len(gs_new) >= 2
_logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
return posemb
@ -385,20 +387,20 @@ def checkpoint_filter_fn(state_dict, model):
v = v.reshape(O, -1, H, W)
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(v, model.pos_embed, getattr(model, 'num_tokens', 1))
v = resize_pos_embed(
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
out_dict[k] = v
return out_dict
def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kwargs):
if default_cfg is None:
default_cfg = deepcopy(default_cfgs[variant])
overlay_external_default_cfg(default_cfg, kwargs)
default_num_classes = default_cfg['num_classes']
default_img_size = default_cfg['input_size'][-2:]
default_cfg = default_cfg or default_cfgs[variant]
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
num_classes = kwargs.pop('num_classes', default_num_classes)
img_size = kwargs.pop('img_size', default_img_size)
# NOTE this extra code to support handling of repr size for in21k pretrained models
default_num_classes = default_cfg['num_classes']
num_classes = kwargs.get('num_classes', default_num_classes)
repr_size = kwargs.pop('representation_size', None)
if repr_size is not None and num_classes != default_num_classes:
# Remove representation layer if fine-tuning. This may not always be the desired action,
@ -406,18 +408,12 @@ def _create_vision_transformer(variant, pretrained=False, default_cfg=None, **kw
_logger.warning("Removing representation layer for fine-tuning.")
repr_size = None
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
model = build_model_with_cfg(
VisionTransformer, variant, pretrained,
default_cfg=default_cfg,
img_size=img_size,
num_classes=num_classes,
representation_size=repr_size,
pretrained_filter_fn=checkpoint_filter_fn,
**kwargs)
return model

@ -27,7 +27,7 @@ def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
'crop_pct': .9, 'interpolation': 'bicubic',
'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
'mean': (0.5, 0.5, 0.5), 'std': (0.5, 0.5, 0.5),
'first_conv': 'patch_embed.backbone.stem.conv', 'classifier': 'head',
**kwargs
@ -107,11 +107,10 @@ class HybridEmbed(nn.Module):
def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
default_cfg = deepcopy(default_cfgs[variant])
embed_layer = partial(HybridEmbed, backbone=backbone)
kwargs.setdefault('patch_size', 1) # default patch size for hybrid models if not set
return _create_vision_transformer(
variant, pretrained=pretrained, default_cfg=default_cfg, embed_layer=embed_layer, **kwargs)
variant, pretrained=pretrained, embed_layer=embed_layer, default_cfg=default_cfgs[variant], **kwargs)
def _resnetv2(layers=(3, 4, 9), **kwargs):

@ -1 +1 @@
__version__ = '0.4.9'
__version__ = '0.4.11'

Loading…
Cancel
Save