Merge pull request #668 from rwightman/more_attn

Add Gather-Excite, Global Context, BAT, Non-Local attn modules and refactored all attn modules and factory for improved consistency. EfficientNet / MobileNetV3 backbones able to use a wider variety of attention modules.
pull/679/head
Ross Wightman 3 years ago committed by GitHub
commit 54a6cca27a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -36,11 +36,16 @@ jobs:
run: pip install --no-cache-dir torch==${{ matrix.torch }} torchvision==${{ matrix.torchvision }}
- name: Install torch on ubuntu
if: startsWith(matrix.os, 'ubuntu')
run: pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
run: |
pip install --no-cache-dir torch==${{ matrix.torch }}+cpu torchvision==${{ matrix.torchvision }}+cpu -f https://download.pytorch.org/whl/torch_stable.html
sudo apt update
sudo apt install -y google-perftools
- name: Install requirements
run: |
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
pip install --no-cache-dir git+https://github.com/mapillary/inplace_abn.git@v1.0.12
- name: Run tests
env:
LD_PRELOAD: /usr/lib/x86_64-linux-gnu/libtcmalloc.so.4
run: |
pytest -vv --durations=0 ./tests

@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
* SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropBlock (https://arxiv.org/abs/1810.12890)
* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
* An extensive selection of channel and/or spatial attention modules:
* Bottleneck Transformer - https://arxiv.org/abs/2101.11605
* CBAM - https://arxiv.org/abs/1807.06521
* Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
* Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
* Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
* Global Context (GC) - https://arxiv.org/abs/1904.11492
* Halo - https://arxiv.org/abs/2103.12731
* Involution - https://arxiv.org/abs/2103.06255
* Lambda Layer - https://arxiv.org/abs/2102.08602
* Non-Local (NL) - https://arxiv.org/abs/1711.07971
* Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
* Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
* Split (SPLAT) - https://arxiv.org/abs/2004.08955
* Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030
## Results

@ -24,7 +24,7 @@ NUM_NON_STD = len(NON_STD_FILTERS)
if 'GITHUB_ACTIONS' in os.environ: # and 'Linux' in platform.system():
# GitHub Linux runner is slower and hits memory limits sooner than MacOS, exclude bigger models
EXCLUDE_FILTERS = [
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm',
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*',
'*resnetrs350*', '*resnetrs420*']
else:

@ -17,7 +17,6 @@ from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
#from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *

@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\
reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\
make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByoaNet']
__all__ = []
def _cfg(url='', **kwargs):
@ -63,100 +51,68 @@ default_cfgs = {
'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
}
@dataclass
class ByoaBlocksCfg(BlocksCfg):
    """ Per-stage block config used by ByoaNet; adds no fields of its own on top of BlocksCfg. """
    # FIXME allow overriding self_attn layer or args per block/stage,
    pass
@dataclass
class ByoaCfg(ByobCfg):
    """ Model config for ByoaNet: the ByobCfg fields plus self-attention layer selection. """
    # one entry per stage; an entry is either a single block cfg or a tuple of block cfgs
    blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
    # self-attn layer name, resolved with get_self_attn() in get_layer_fns(); falsy -> no self-attn fn
    self_attn_layer: Optional[str] = None
    # when True, _byoa_block_args() requires and passes `feat_size` to 'self_attn' blocks
    self_attn_fixed_size: bool = False
    # extra kwargs bound onto the self-attn layer via functools.partial in get_layer_fns()
    self_attn_kwargs: dict = field(default_factory=lambda: dict())
def interleave_attn(
        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg, ...]:
    """ Interleave two block types across a stack of `d` single-depth block configs.

    Args:
        types: Pair of block type names. `types[0]` is the default type, `types[1]` is used at
            the indices selected by `every`.
        every: Either an int stride or an explicit list of block indices that receive `types[1]`.
        d: Total number of blocks in the stack.
        first: With an int `every`, start selecting at index 0 instead of at index `every`.
        **kwargs: Forwarded unchanged to every `ByoaBlocksCfg` (e.g. c, s, gs, br).

    Returns:
        Tuple of `d` block configs, each created with `d=1`.
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            # the stride never lands inside the stack, fall back to the last block
            every = [d - 1]
    # The previous bare `set(every)` discarded its result; bind it so membership tests use the set.
    every = set(every)
    return tuple(
        ByoaBlocksCfg(type=types[1] if i in every else types[0], d=1, **kwargs)
        for i in range(d)
    )
model_cfgs = dict(
botnet26t=ByoaCfg(
botnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
botnet50ts=ByoaCfg(
botnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
eca_botnext26ts=ByoaCfg(
eca_botnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
halonet_h1=ByoaCfg(
halonet_h1=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='7x7',
@ -165,12 +121,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet_h1_c4c5=ByoaCfg(
halonet_h1_c4c5=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='tiered',
@ -179,12 +135,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet26t=ByoaCfg(
halonet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -193,12 +149,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
halonet50ts=ByoaCfg(
halonet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -208,12 +164,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2)
),
eca_halonext26ts=ByoaCfg(
eca_halonext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -225,12 +181,12 @@ model_cfgs = dict(
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
lambda_resnet26t=ByoaCfg(
lambda_resnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -239,12 +195,12 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
lambda_resnet50t=ByoaCfg(
lambda_resnet50t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -253,12 +209,12 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
eca_lambda_resnext26ts=ByoaCfg(
eca_lambda_resnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -270,77 +226,76 @@ model_cfgs = dict(
self_attn_kwargs=dict()
),
swinnet26t=ByoaCfg(
swinnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
swinnet50ts=ByoaCfg(
swinnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
eca_swinnext26ts=ByoaCfg(
eca_swinnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
rednet26t=ByoaCfg(
rednet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered', # FIXME RedNet uses involution in middle of stem
stem_pool='maxpool',
num_features=0,
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
rednet50ts=ByoaCfg(
rednet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -348,161 +303,14 @@ model_cfgs = dict(
num_features=0,
act_layer='silu',
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
)
@dataclass
class ByoaLayerFn(LayerFn):
    """ LayerFn extended with a self-attention layer factory. """
    # factory for the self-attn layer; SelfAttnBlock calls it as self_attn(mid_chs, stride=..., [feat_size=...])
    self_attn: Optional[Callable] = None
class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1

    The residual branch runs conv1_1x1 -> conv2_kxk -> self_attn -> post_attn -> conv3_1x1 -> drop_path
    and is summed with the shortcut before the final activation (see `forward`).
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
                 downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
                 layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
        """
        Args:
            in_chs: Input channels of the block.
            out_chs: Output channels of the block.
            kernel_size: Kernel size of the optional kxk conv (only used when `extra_conv` is set).
            stride: Block stride; applied by the kxk conv when `extra_conv` is set, otherwise
                passed to the self-attn layer.
            dilation: Pair of dilation values; only `dilation[0]` is passed to the convs / shortcut.
            bottle_ratio: Multiplier on `out_chs` giving the mid (bottleneck) channel count.
            group_size: Passed to `num_groups` together with the mid channel count.
            downsample: Downsample type forwarded to `create_downsample` for the shortcut.
            extra_conv: Insert a kxk conv before the self-attn layer.
            linear_out: Use an identity instead of an activation after the residual sum.
            post_attn_na: Apply `layers.norm_act` after the self-attn layer.
            feat_size: Optional feature size, forwarded to the self-attn layer only when not None.
            layers: Layer factory bundle (required, asserted non-None).
            drop_block: Forwarded to the kxk conv when `extra_conv` is set.
            drop_path_rate: DropPath rate; an identity is used when not > 0.
        """
        super(SelfAttnBlock, self).__init__()
        assert layers is not None
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # a non-identity shortcut is built when channels change, the block is strided,
        # or the two dilation values differ
        if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
                apply_act=False, layers=layers)
        else:
            self.shortcut = nn.Identity()

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        if extra_conv:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
                groups=groups, drop_block=drop_block)
            stride = 1  # striding done via conv if enabled
        else:
            self.conv2_kxk = nn.Identity()
        # only pass feat_size through when the caller supplied one
        opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size)
        # FIXME need to dilate self attn to have dilated network support, moop moop
        self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs)
        self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        self.act = nn.Identity() if linear_out else layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        """ Block-specific init: optionally zero `conv3_1x1.bn.weight`, then let the self-attn
        layer reset its own parameters if it exposes `reset_parameters`.
        """
        if zero_init_last_bn:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()

    def forward(self, x):
        """ Run the residual branch on `x`, add the shortcut of the input, apply the output activation. """
        shortcut = self.shortcut(x)

        x = self.conv1_1x1(x)
        x = self.conv2_kxk(x)
        x = self.self_attn(x)
        x = self.post_attn(x)
        x = self.conv3_1x1(x)
        x = self.drop_path(x)

        x = self.act(x + shortcut)
        return x
# NOTE(review): presumably lets block cfgs with type='self_attn' resolve to SelfAttnBlock -- confirm in byobnet.register_block
register_block('self_attn', SelfAttnBlock)
def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
    """ Add `feat_size` to the block kwargs for 'self_attn' blocks of fixed-size self-attn models.

    `block_kwargs` is updated in place and also returned. For any other block type, or when
    `model_cfg.self_attn_fixed_size` is falsy, it is returned untouched.
    """
    wants_feat_size = block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size
    if not wants_feat_size:
        return block_kwargs

    # fixed-size self-attn needs to know the feature map size up front
    assert feat_size is not None
    block_kwargs['feat_size'] = feat_size
    return block_kwargs
def get_layer_fns(cfg: ByoaCfg):
    """ Resolve the layer names and kwargs in `cfg` into a `ByoaLayerFn` bundle of layer factories.

    The `attn` / `self_attn` entries stay None when the matching cfg layer name is falsy.
    """
    act_fn = get_act_layer(cfg.act_layer)
    norm_act_fn = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act_fn)
    conv_norm_act_fn = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act_fn)

    attn_fn = None
    if cfg.attn_layer:
        attn_fn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs)

    self_attn_fn = None
    if cfg.self_attn_layer:
        self_attn_fn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs)

    return ByoaLayerFn(
        conv_norm_act=conv_norm_act_fn,
        norm_act=norm_act_fn,
        act=act_fn,
        attn=attn_fn,
        self_attn=self_attn_fn,
    )
class ByoaNet(nn.Module):
    """ 'Bring-your-own-attention' Net

    A ResNet inspired backbone that supports interleaving traditional residual blocks with
    'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention
    or similar module.

    FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
    torchscript limitations prevent sensible inheritance overrides.
    """
    def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
        """
        Args:
            cfg: Model config describing stem, stages and layer choices.
            num_classes: Number of classifier outputs.
            in_chans: Number of input channels passed to the stem.
            output_stride: Forwarded to `create_byob_stages`.
            global_pool: Pooling type for the classifier head.
            zero_init_last_bn: Forwarded to each module's `init_weights` hook.
            img_size: Optional input size; when given it is turned into a 2-tuple feature size that
                is reduced by the stem's reduction and passed on to stage creation.
            drop_rate: Dropout rate for the classifier head.
            drop_path_rate: Forwarded to `create_byob_stages`.
        """
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        layers = get_layer_fns(cfg)
        feat_size = to_2tuple(img_size) if img_size is not None else None

        self.feature_info = []
        # fall back to the first stage's channel count when cfg.stem_chs is falsy
        stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
        # last stem feature is handed to stage creation below instead of being recorded here
        self.feature_info.extend(stem_feat[:-1])
        feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])

        self.stages, stage_feat = create_byob_stages(
            cfg, drop_path_rate, output_stride, stem_feat[-1],
            feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
        # last stage feature is recorded after the (optional) final conv below
        self.feature_info.extend(stage_feat[:-1])

        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        else:
            # no final 1x1 conv, features come straight from the last stage
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.feature_info += [
            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        for n, m in self.named_modules():
            _init_weights(m, n)
        for m in self.modules():
            # call each block's weight init for block-specific overrides to init above
            if hasattr(m, 'init_weights'):
                m.init_weights(zero_init_last_bn=zero_init_last_bn)

    def get_classifier(self):
        """ Return the `fc` module of the classifier head. """
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """ Replace the head with a new ClassifierHead for `num_classes` outputs and `global_pool` pooling. """
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        """ Run stem, stages and final conv; the classifier head is not applied. """
        x = self.stem(x)
        x = self.stages(x)
        x = self.final_conv(x)
        return x

    def forward(self, x):
        """ Run `forward_features` followed by the classifier head. """
        x = self.forward_features(x)
        x = self.head(x)
        return x
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByoaNet, variant, pretrained,
ByobNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),

@ -26,8 +26,7 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field, replace
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
from functools import partial
import torch
@ -36,10 +35,10 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
def _cfg(url='', **kwargs):
@ -87,35 +86,59 @@ default_cfgs = {
'repvgg_b3g4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
# experimental configs
'resnet51q': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet51q_ra2-d47dcc76.pth',
first_conv='stem.conv1', input_size=(3, 256, 256), pool_size=(8, 8),
test_input_size=(3, 288, 288), crop_pct=1.0),
'resnet61q': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'geresnet50t': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
'gcresnet50t': _cfg(
first_conv='stem.conv1.conv', input_size=(3, 256, 256), pool_size=(8, 8), interpolation='bicubic'),
}
@dataclass
class BlocksCfg:
class ByoBlockCfg:
type: Union[str, nn.Module]
d: int # block depth (number of block repeats in stage)
c: int # number of output channels for each block in stage
s: int = 2 # stride of stage (first block)
gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
br: float = 1. # bottleneck-ratio of blocks in stage
no_attn: bool = False # disable channel attn (ie SE) when layer is set for model
# NOTE: these config items override the model cfgs that are applied to all blocks by default
attn_layer: Optional[str] = None
attn_kwargs: Optional[Dict[str, Any]] = None
self_attn_layer: Optional[str] = None
self_attn_kwargs: Optional[Dict[str, Any]] = None
block_kwargs: Optional[Dict[str, Any]] = None
@dataclass
class ByobCfg:
blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
class ByoModelCfg:
blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
downsample: str = 'conv1x1'
stem_type: str = '3x3'
stem_pool: str = ''
stem_pool: Optional[str] = 'maxpool'
stem_chs: int = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last_bn: bool = True
fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
# NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
attn_layer: Optional[str] = None
attn_kwargs: dict = field(default_factory=lambda: dict())
self_attn_layer: Optional[str] = None
self_attn_kwargs: dict = field(default_factory=lambda: dict())
block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
@ -123,103 +146,287 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
group_size = 0
if groups > 0:
group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
bcfg = tuple([BlocksCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
return bcfg
model_cfgs = dict(
def interleave_blocks(
        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoBlockCfg, ...]:
    """ Interleave 2 block types in a stack of `d` single-depth block configs.

    Args:
        types: Pair of block type names. `types[0]` is the default type, `types[1]` is used at
            the indices selected by `every`.
        every: Either an int stride or an explicit list of block indices that receive `types[1]`.
        d: Total number of blocks in the stack.
        first: With an int `every`, start selecting at index 0 instead of at index `every`.
        **kwargs: Forwarded unchanged to every `ByoBlockCfg` (e.g. c, s, gs, br).

    Returns:
        Tuple of `d` block configs, each created with `d=1`.
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            # the stride never lands inside the stack, fall back to the last block
            every = [d - 1]
    # The previous bare `set(every)` discarded its result; bind it so membership tests use the set.
    every = set(every)
    return tuple(
        ByoBlockCfg(type=types[1] if i in every else types[0], d=1, **kwargs)
        for i in range(d)
    )
gernet_l=ByobCfg(
model_cfgs = dict(
gernet_l=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_m=ByobCfg(
gernet_m=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_s=ByobCfg(
gernet_s=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
),
stem_chs=13,
stem_pool=None,
num_features=1920,
),
repvgg_a2=ByobCfg(
repvgg_a2=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
stem_type='rep',
stem_chs=64,
),
repvgg_b0=ByobCfg(
repvgg_b0=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1=ByobCfg(
repvgg_b1=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1g4=ByobCfg(
repvgg_b1g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b2=ByobCfg(
repvgg_b2=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b2g4=ByobCfg(
repvgg_b2g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b3=ByobCfg(
repvgg_b3=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b3g4=ByobCfg(
repvgg_b3g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
resnet52q=ByobCfg(
# WARN: experimental, may vanish/change
resnet51q=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
),
stem_chs=128,
stem_type='quad2',
stem_pool=None,
num_features=2048,
act_layer='silu',
),
resnet61q=ByoModelCfg(
blocks=(
BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
ByoBlockCfg(type='edge', d=1, c=256, s=1, gs=0, br=1.0, block_kwargs=dict()),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
),
stem_chs=128,
stem_type='quad',
stem_pool=None,
num_features=2048,
act_layer='silu',
block_kwargs=dict(extra_conv=True),
),
# WARN: experimental, may vanish/change
geresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='ge',
attn_kwargs=dict(extent=8, extra_params=True),
#attn_kwargs=dict(extent=8),
#block_kwargs=dict(attn_last=True)
),
# WARN: experimental, may vanish/change
gcresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='gc'
),
)
def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
@register_model
def gernet_l(pretrained=False, **kwargs):
    """ Create the GEResNet-Large model (GENet-Large in the official implementation).

    Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
    return model
@register_model
def gernet_m(pretrained=False, **kwargs):
    """ Create the GEResNet-Medium model (GENet-Normal in the official implementation).

    Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
    return model
@register_model
def gernet_s(pretrained=False, **kwargs):
    """ Create the GEResNet-Small model (GENet-Small in the official implementation).

    Paper: `Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_a2(pretrained=False, **kwargs):
    """ Create the RepVGG-A2 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b0(pretrained=False, **kwargs):
    """ Create the RepVGG-B0 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b1(pretrained=False, **kwargs):
    """ Create the RepVGG-B1 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b1g4(pretrained=False, **kwargs):
    """ Create the RepVGG-B1g4 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b2(pretrained=False, **kwargs):
    """ Create the RepVGG-B2 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b2g4(pretrained=False, **kwargs):
    """ Create the RepVGG-B2g4 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b3(pretrained=False, **kwargs):
    """ Create the RepVGG-B3 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
    return model
@register_model
def repvgg_b3g4(pretrained=False, **kwargs):
    """ Create the RepVGG-B3g4 model.

    Paper: `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)
    return model
@register_model
def resnet51q(pretrained=False, **kwargs):
    """ Create the `resnet51q` model from its `model_cfgs` entry.

    The config uses bottleneck stages, a 'quad2' stem and SiLU activation.
    WARN: experimental config, may vanish/change.
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('resnet51q', pretrained=pretrained, **kwargs)
    return model
@register_model
def resnet61q(pretrained=False, **kwargs):
    """ Create the `resnet61q` model from its `model_cfgs` entry.

    The config uses a 'quad' stem, SiLU activation and model-level `block_kwargs=dict(extra_conv=True)`.
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('resnet61q', pretrained=pretrained, **kwargs)
    return model
@register_model
def geresnet50t(pretrained=False, **kwargs):
    """ Create the `geresnet50t` model from its `model_cfgs` entry.

    The config uses a 'tiered' stem and `attn_layer='ge'` (presumably Gather-Excite) with
    `attn_kwargs=dict(extent=8, extra_params=True)`.
    WARN: experimental config, may vanish/change.
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs)
    return model
@register_model
def gcresnet50t(pretrained=False, **kwargs):
    """ Create the `gcresnet50t` model from its `model_cfgs` entry.

    The config uses bottleneck stages, a 'tiered' stem and `attn_layer='gc'`
    (presumably Global Context attention).
    WARN: experimental config, may vanish/change.
    `pretrained` and all extra keyword arguments are forwarded to `_create_byobnet`.
    """
    model = _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)
    return model
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
block_cfgs = []
@ -243,6 +450,7 @@ class LayerFn:
norm_act: Callable = BatchNormAct2d
act: Callable = nn.ReLU
attn: Optional[Callable] = None
self_attn: Optional[Callable] = None
class DownsampleAvg(nn.Module):
@ -275,7 +483,8 @@ class BasicBlock(nn.Module):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -289,15 +498,19 @@ class BasicBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -317,7 +530,8 @@ class BottleneckBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, extra_conv=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -334,22 +548,36 @@ class BottleneckBlock(nn.Module):
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
if extra_conv:
self.conv2b_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block)
else:
self.conv2b_kxk = nn.Identity()
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.conv2b_kxk(x)
x = self.attn(x)
x = self.conv3_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
@ -368,7 +596,8 @@ class DarkBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -382,23 +611,28 @@ class DarkBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.attn(x)
x = self.conv2_kxk(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -415,7 +649,8 @@ class EdgeBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -431,14 +666,18 @@ class EdgeBlock(nn.Module):
self.conv1_kxk = layers.conv_norm_act(
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -446,6 +685,7 @@ class EdgeBlock(nn.Module):
x = self.conv1_kxk(x)
x = self.attn(x)
x = self.conv2_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -460,7 +700,7 @@ class RepVggBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -475,12 +715,14 @@ class RepVggBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
# NOTE this init overrides that base model init with specific changes for the block type
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
nn.init.normal_(m.weight, .1, .1)
nn.init.normal_(m.bias, 0, .1)
if hasattr(self.attn, 'reset_parameters'):
self.attn.reset_parameters()
def forward(self, x):
if self.identity is None:
@ -495,12 +737,68 @@ class RepVggBlock(nn.Module):
return x
class SelfAttnBlock(nn.Module):
    """ ResNet-like bottleneck block built around a self-attention layer.

    Layer order: 1x1 conv -> optional kxk conv -> self attn -> optional norm + act -> 1x1 conv.
    The result goes through drop path, is summed with the shortcut branch and then activated.
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
                 downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
                 layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super().__init__()
        # layers must be supplied by the caller, it provides the self_attn factory used below
        assert layers is not None
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # NOTE modules are assigned in the same order as before so registration order is unchanged
        if in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
                apply_act=False, layers=layers)

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)

        if not extra_conv:
            self.conv2_kxk = nn.Identity()
        else:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
                groups=groups, drop_block=drop_block)
            # the extra conv already applied the stride, so the self attn layer must not stride again
            stride = 1

        # feat_size is only passed on to the self attn layer when it was given
        attn_extra = dict(feat_size=feat_size) if feat_size is not None else {}
        # FIXME need to dilate self attn to have dilated network support
        self.self_attn = layers.self_attn(mid_chs, stride=stride, **attn_extra)

        if post_attn_na:
            self.post_attn = layers.norm_act(mid_chs)
        else:
            self.post_attn = nn.Identity()

        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)

        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()

        if linear_out:
            self.act = nn.Identity()
        else:
            self.act = layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn: bool = False):
        """ Optionally zero the last BN weight, then reset the self attn params if it supports that. """
        if zero_init_last_bn:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        if hasattr(self.self_attn, 'reset_parameters'):
            self.self_attn.reset_parameters()

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.conv1_1x1(x)
        out = self.conv2_kxk(out)
        out = self.self_attn(out)
        out = self.post_attn(out)
        out = self.conv3_1x1(out)
        out = self.drop_path(out)

        return self.act(out + residual)
# Block classes keyed by short type name; these are the names used for `type` in block configs
# (e.g. 'bottle', 'edge') -- presumably resolved through create_block.
_block_registry = {
    'basic': BasicBlock,
    'bottle': BottleneckBlock,
    'dark': DarkBlock,
    'edge': EdgeBlock,
    'rep': RepVggBlock,
    'self_attn': SelfAttnBlock,
}
@ -552,7 +850,7 @@ class Stem(nn.Sequential):
curr_stride *= s
prev_feat = conv_name
if 'max' in pool.lower():
if pool and 'max' in pool.lower():
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
self.add_module('pool', nn.MaxPool2d(3, 2, 1))
curr_stride *= 2
@ -564,7 +862,7 @@ class Stem(nn.Sequential):
def create_byob_stem(in_chs, out_chs, stem_type='', pool_type='', feat_prefix='stem', layers: LayerFn = None):
layers = layers or LayerFn()
assert stem_type in ('', 'quad', 'tiered', 'deep', 'rep', '7x7', '3x3')
assert stem_type in ('', 'quad', 'quad2', 'tiered', 'deep', 'rep', '7x7', '3x3')
if 'quad' in stem_type:
# based on NFNet stem, stack of 4 3x3 convs
num_act = 2 if 'quad2' in stem_type else None
@ -601,9 +899,58 @@ def reduce_feat_size(feat_size, stride=2):
return None if feat_size is None else tuple([s // stride for s in feat_size])
def override_kwargs(block_kwargs, model_kwargs):
    """ Pick the kwargs to use for a block: block level wins over model level.

    NOTE: the two levels are NOT merged. Whenever block_kwargs is anything other than None it
    replaces model_kwargs entirely, so an empty block_kwargs dict drops the model level kwargs
    for that block.

    Always returns a dict, never None.
    """
    if block_kwargs is None:
        selected = model_kwargs
    else:
        selected = block_kwargs
    if not selected:
        # covers None as well as an empty mapping
        return {}
    return selected
def _resolve_block_attn(block_layer, block_attn_kwargs, model_layer, model_attn_kwargs):
    """ Build the (self-)attn factory for one block from block level and model level settings.

    Args:
        block_layer: attn layer set on the block cfg. None means 'not set, inherit from model',
            any other falsy value (e.g. an empty string) disables attn for the block.
        block_attn_kwargs: attn kwargs set on the block cfg, or None if not set.
        model_layer: attn layer set on the model cfg.
        model_attn_kwargs: attn kwargs set on the model cfg.

    Returns:
        A `partial` binding the resolved kwargs to whatever `get_attn` returns for the layer,
        or None if the block ends up without an attn layer.
    """
    if block_layer is not None and not block_layer:
        # explicitly disabled for this block
        return None
    attn_kwargs = override_kwargs(block_attn_kwargs, model_attn_kwargs)
    # block_layer is None when only the kwargs were overridden, fall back to the model level layer
    attn_layer = block_layer or model_layer
    if not attn_layer:
        # nothing configured at either level (same truthiness test get_layer_fns uses)
        return None
    # kwargs must be bound by keyword, unpacking with a single * would pass the dict keys positionally
    return partial(get_attn(attn_layer), **attn_kwargs)


def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg):
    """ Apply block level overrides to the kwargs used to construct a single block.

    `block_kwargs` is modified in place: its 'layers' entry is replaced by a copy with the
    attn / self_attn factories overridden where the block cfg sets a layer or kwargs, and the
    extra block kwargs from the block cfg (or, if unset, from the model cfg) are added.

    Args:
        block_kwargs: kwargs dict for the block being built, must contain a 'layers' entry.
        block_cfg: config of the block being built.
        model_cfg: config of the whole model, supplies the fallback values.
    """
    layer_fns = block_kwargs['layers']

    # override attn layer / args with block local config
    if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
        attn_fn = _resolve_block_attn(
            block_cfg.attn_layer, block_cfg.attn_kwargs, model_cfg.attn_layer, model_cfg.attn_kwargs)
        layer_fns = replace(layer_fns, attn=attn_fn)

    # override self-attn layer / args with block local cfg
    if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
        self_attn_fn = _resolve_block_attn(
            block_cfg.self_attn_layer, block_cfg.self_attn_kwargs,
            model_cfg.self_attn_layer, model_cfg.self_attn_kwargs)
        layer_fns = replace(layer_fns, self_attn=self_attn_fn)

    block_kwargs['layers'] = layer_fns

    # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
    block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
def create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat,
feat_size=None, layers=None, extra_args_fn=None):
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
layers = layers or LayerFn()
feature_info = []
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
@ -641,8 +988,10 @@ def create_byob_stages(
drop_path_rate=dpr[stage_idx][block_idx],
layers=layers,
)
if extra_args_fn is not None:
extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
if block_cfg.type in ('self_attn',):
# add feat_size arg for blocks that support/need it
block_kwargs['feat_size'] = feat_size
block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
blocks += [create_block(block_cfg.type, **block_kwargs)]
first_dilation = dilation
prev_chs = out_chs
@ -656,12 +1005,13 @@ def create_byob_stages(
return nn.Sequential(*stages), feature_info
def get_layer_fns(cfg: ByobCfg):
def get_layer_fns(cfg: ByoModelCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
return layer_fn
@ -673,19 +1023,24 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByobCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, drop_rate=0., drop_path_rate=0.):
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
self.stages, stage_feat = create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size)
self.feature_info.extend(stage_feat[:-1])
prev_chs = stage_feat[-1]['num_chs']
@ -748,91 +1103,3 @@ def _create_byobnet(variant, pretrained=False, **kwargs):
model_cfg=model_cfgs[variant],
feature_cfg=dict(flatten_sequential=True),
**kwargs)
@register_model
def gernet_l(pretrained=False, **kwargs):
""" GEResNet-Large (GENet-Large from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_l', pretrained=pretrained, **kwargs)
@register_model
def gernet_m(pretrained=False, **kwargs):
""" GEResNet-Medium (GENet-Normal from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_m', pretrained=pretrained, **kwargs)
@register_model
def gernet_s(pretrained=False, **kwargs):
""" EResNet-Small (GENet-Small from official impl)
`Neural Architecture Design for GPU-Efficient Networks` - https://arxiv.org/abs/2006.14090
"""
return _create_byobnet('gernet_s', pretrained=pretrained, **kwargs)
@register_model
def repvgg_a2(pretrained=False, **kwargs):
""" RepVGG-A2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_a2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b0(pretrained=False, **kwargs):
""" RepVGG-B0
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b0', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1(pretrained=False, **kwargs):
""" RepVGG-B1
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b1g4(pretrained=False, **kwargs):
""" RepVGG-B1g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b1g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2(pretrained=False, **kwargs):
""" RepVGG-B2
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b2g4(pretrained=False, **kwargs):
""" RepVGG-B2g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b2g4', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3(pretrained=False, **kwargs):
""" RepVGG-B3
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3', pretrained=pretrained, **kwargs)
@register_model
def repvgg_b3g4(pretrained=False, **kwargs):
""" RepVGG-B3g4
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)

@ -91,6 +91,12 @@ default_cfgs = {
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth',
interpolation='bilinear'),
# NOTE experimenting with alternate attention
'eca_efficientnet_b0': _cfg(
url=''),
'gc_efficientnet_b0': _cfg(
url=''),
'efficientnet_b0': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth'),
'efficientnet_b1': _cfg(
@ -1223,6 +1229,26 @@ def efficientnet_b0(pretrained=False, **kwargs):
return model
@register_model
def eca_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 w/ ECA attn

    Same multipliers as the B0 config (channel 1.0, depth 1.0) with `se_layer='ecam'`.
    """
    # NOTE experimental config
    return _gen_efficientnet(
        'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
        pretrained=pretrained, **kwargs)
@register_model
def gc_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 w/ GlobalContext

    Same multipliers as the B0 config (channel 1.0, depth 1.0) with `se_layer='gc'`.
    """
    # NOTE experimental config
    return _gen_efficientnet(
        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
        pretrained=pretrained, **kwargs)
@register_model
def efficientnet_b1(pretrained=False, **kwargs):
""" EfficientNet-B1 """

@ -7,7 +7,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
from .layers.activations import sigmoid
__all__ = [
@ -19,33 +19,32 @@ class SqueezeExcite(nn.Module):
Args:
in_chs (int): input channels to layer
se_ratio (float): ratio of squeeze reduction
rd_ratio (float): ratio of squeeze reduction
act_layer (nn.Module): activation layer of containing block
gate_fn (Callable): attention gate function
block_in_chs (int): input channels of containing block (for calculating reduction from)
reduce_from_block (bool): calculate reduction from block input channels if True
gate_layer (Callable): attention gate function
force_act_layer (nn.Module): override block's activation fn if this is set/bound
divisor (int): make reduction channels divisible by this
rd_round_fn (Callable): specify a fn to calculate rounding of reduced chs
"""
def __init__(
self, in_chs, se_ratio=0.25, act_layer=nn.ReLU, gate_fn=sigmoid,
block_in_chs=None, reduce_from_block=True, force_act_layer=None, divisor=1):
self, in_chs, rd_ratio=0.25, rd_channels=None, act_layer=nn.ReLU,
gate_layer=nn.Sigmoid, force_act_layer=None, rd_round_fn=None):
super(SqueezeExcite, self).__init__()
reduced_chs = (block_in_chs or in_chs) if reduce_from_block else in_chs
reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
if rd_channels is None:
rd_round_fn = rd_round_fn or round
rd_channels = rd_round_fn(in_chs * rd_ratio)
act_layer = force_act_layer or act_layer
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
self.conv_reduce = nn.Conv2d(in_chs, rd_channels, 1, bias=True)
self.act1 = create_act_layer(act_layer, inplace=True)
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
self.gate_fn = get_act_fn(gate_fn)
self.conv_expand = nn.Conv2d(rd_channels, in_chs, 1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.conv_reduce(x_se)
x_se = self.act1(x_se)
x_se = self.conv_expand(x_se)
return x * self.gate_fn(x_se)
return x * self.gate(x_se)
class ConvBnAct(nn.Module):
@ -87,10 +86,9 @@ class DepthwiseSeparableConv(nn.Module):
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, pw_kernel_size=1, pw_act=False, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
se_layer=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
@ -101,7 +99,7 @@ class DepthwiseSeparableConv(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = se_layer(in_chs, se_ratio=se_ratio, act_layer=act_layer) if has_se else nn.Identity()
self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs)
@ -146,12 +144,11 @@ class InvertedResidual(nn.Module):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
has_se = se_layer is not None and se_ratio > 0.
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
@ -168,8 +165,7 @@ class InvertedResidual(nn.Module):
self.act2 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = se_layer(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
@ -215,8 +211,8 @@ class CondConvResidual(InvertedResidual):
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
self.num_experts = num_experts
conv_kwargs = dict(num_experts=self.num_experts)
@ -224,8 +220,8 @@ class CondConvResidual(InvertedResidual):
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_layer=se_layer,
norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate)
pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
self.routing_fn = nn.Linear(in_chs, self.num_experts)
@ -274,8 +270,8 @@ class EdgeResidual(nn.Module):
def __init__(
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, se_ratio=0.,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
super(EdgeResidual, self).__init__()
if force_in_chs > 0:
mid_chs = make_divisible(force_in_chs * exp_ratio)
@ -292,8 +288,7 @@ class EdgeResidual(nn.Module):
self.act1 = act_layer(inplace=True)
# Squeeze-and-excitation
self.se = SqueezeExcite(
mid_chs, se_ratio=se_ratio, act_layer=act_layer, block_in_chs=in_chs) if has_se else nn.Identity()
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)

@ -10,11 +10,12 @@ import logging
import math
import re
from copy import deepcopy
from functools import partial
import torch.nn as nn
from .efficientnet_blocks import *
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, make_divisible
from .layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@ -120,7 +121,9 @@ def _decode_block_str(block_str):
elif v == 'hs':
value = get_act_layer('hard_swish')
elif v == 'sw':
value = get_act_layer('swish')
value = get_act_layer('swish') # aka SiLU
elif v == 'mi':
value = get_act_layer('mish')
else:
continue
options[key] = value
@ -265,14 +268,20 @@ class EfficientNetBuilder:
https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
"""
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels,
def __init__(self, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=False,
act_layer=None, norm_layer=None, se_layer=None, drop_path_rate=0., feature_location=''):
self.output_stride = output_stride
self.pad_type = pad_type
self.round_chs_fn = round_chs_fn
self.se_from_exp = se_from_exp # calculate se channel reduction from expanded (mid) chs
self.act_layer = act_layer
self.norm_layer = norm_layer
self.se_layer = se_layer
self.se_layer = get_attn(se_layer)
try:
self.se_layer(8, rd_ratio=1.0) # test if attn layer accepts rd_ratio arg
self.se_has_ratio = True
except TypeError:
self.se_has_ratio = False
self.drop_path_rate = drop_path_rate
if feature_location == 'depthwise':
# old 'depthwise' mode renamed 'expansion' to match TF impl, old expansion mode didn't make sense
@ -299,16 +308,21 @@ class EfficientNetBuilder:
ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer
assert ba['act_layer'] is not None
ba['norm_layer'] = self.norm_layer
ba['drop_path_rate'] = drop_path_rate
if bt != 'cn':
ba['se_layer'] = self.se_layer
ba['drop_path_rate'] = drop_path_rate
se_ratio = ba.pop('se_ratio')
if se_ratio and self.se_layer is not None:
if not self.se_from_exp:
# adjust se_ratio by expansion ratio if calculating se channels from block input
se_ratio /= ba.get('exp_ratio', 1.0)
if self.se_has_ratio:
ba['se_layer'] = partial(self.se_layer, rd_ratio=se_ratio)
else:
ba['se_layer'] = self.se_layer
if bt == 'ir':
_log_info_if(' InvertedResidual {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
if ba.get('num_experts', 0) > 0:
block = CondConvResidual(**ba)
else:
block = InvertedResidual(**ba)
block = CondConvResidual(**ba) if ba.get('num_experts', 0) else InvertedResidual(**ba)
elif bt == 'ds' or bt == 'dsa':
_log_info_if(' DepthwiseSeparable {}, Args: {}'.format(block_idx, str(ba)), self.verbose)
block = DepthwiseSeparableConv(**ba)
@ -418,28 +432,28 @@ def _init_weight_goog(m, n='', fix_group_fanout=True):
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
lambda w: nn.init.normal_(w, 0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight)
if m.bias is not None:
m.bias.data.zero_()
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
nn.init.normal_(m.weight, 0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
fan_out = m.weight.size(0) # fan-out
fan_in = 0
if 'routing_fn' in n:
fan_in = m.weight.size(1)
init_range = 1.0 / math.sqrt(fan_in + fan_out)
m.weight.data.uniform_(-init_range, init_range)
m.bias.data.zero_()
nn.init.uniform_(m.weight, -init_range, init_range)
nn.init.zeros_(m.bias)
def efficientnet_init_weights(model: nn.Module, init_fn=None):

@ -40,7 +40,7 @@ default_cfgs = {
}
_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)
_SE_LAYER = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=partial(make_divisible, divisor=4))
class GhostModule(nn.Module):
@ -92,7 +92,7 @@ class GhostBottleneck(nn.Module):
self.bn_dw = None
# Squeeze-and-excitation
self.se = _SE_LAYER(mid_chs, se_ratio=se_ratio) if has_se else None
self.se = _SE_LAYER(mid_chs, rd_ratio=se_ratio) if has_se else None
# Point-wise linear projection
self.ghost2 = GhostModule(mid_chs, out_chs, relu=False)

@ -4,7 +4,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .efficientnet_blocks import SqueezeExcite
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args
from .efficientnet_builder import decode_arch_def, resolve_act_layer, resolve_bn_args, round_channels
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import get_act_fn
from .mobilenetv3 import MobileNetV3, MobileNetV3Features
@ -39,8 +39,7 @@ def _gen_hardcorenas(pretrained, variant, arch_def, **kwargs):
"""
num_features = 1280
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,

@ -12,26 +12,28 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .create_self_attn import get_self_attn, create_self_attn
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .involution import Involution
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp
from .norm import GroupNorm
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .selective_kernel import SelectiveKernelConv
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool

@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
import torch.nn.functional as F
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
class ChannelAttn(nn.Module):
""" Original CBAM channel attention module, currently avg + max pool variant only.
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(ChannelAttn, self).__init__()
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = x.mean((2, 3), keepdim=True)
x_max = F.adaptive_max_pool2d(x, 1)
x_avg = self.fc2(self.act(self.fc1(x_avg)))
x_max = self.fc2(self.act(self.fc1(x_max)))
x_attn = x_avg + x_max
return x * x_attn.sigmoid()
x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
return x * self.gate(x_avg + x_max)
class LightChannelAttn(ChannelAttn):
"""An experimental 'lightweight' variant that sums avg + max pool first
"""
def __init__(self, channels, reduction=16):
super(LightChannelAttn, self).__init__(channels, reduction)
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(LightChannelAttn, self).__init__(
channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
def forward(self, x):
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
x_attn = self.fc2(self.act(self.fc1(x_pool)))
return x * x_attn.sigmoid()
return x * F.sigmoid(x_attn)
class SpatialAttn(nn.Module):
""" Original CBAM spatial attention module
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(SpatialAttn, self).__init__()
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = torch.cat([x_avg, x_max], dim=1)
x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class LightSpatialAttn(nn.Module):
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(LightSpatialAttn, self).__init__()
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = 0.5 * x_avg + 0.5 * x_max
x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class CbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(CbamModule, self).__init__()
self.channel = ChannelAttn(channels)
self.spatial = SpatialAttn(spatial_kernel_size)
self.channel = ChannelAttn(
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
def forward(self, x):
x = self.channel(x)
@ -87,9 +96,13 @@ class CbamModule(nn.Module):
class LightCbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(LightCbamModule, self).__init__()
self.channel = LightChannelAttn(channels)
self.channel = LightChannelAttn(
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
self.spatial = LightSpatialAttn(spatial_kernel_size)
def forward(self, x):

@ -1,11 +1,23 @@
""" Select AttentionFactory Method
""" Attention Factory
Hacked together by / Copyright 2020 Ross Wightman
Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from functools import partial
from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .selective_kernel import SelectiveKernel
from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule
from .swin_attn import WindowAttention
def get_attn(attn_type):
@ -15,18 +27,54 @@ def get_attn(attn_type):
if attn_type is not None:
if isinstance(attn_type, str):
attn_type = attn_type.lower()
# Lightweight attention modules (channel and/or coarse spatial).
# Typically added to existing network architecture blocks in addition to existing convolutions.
if attn_type == 'se':
module_cls = SEModule
elif attn_type == 'ese':
module_cls = EffectiveSEModule
elif attn_type == 'eca':
module_cls = EcaModule
elif attn_type == 'ecam':
module_cls = partial(EcaModule, use_mlp=True)
elif attn_type == 'ceca':
module_cls = CecaModule
elif attn_type == 'ge':
module_cls = GatherExcite
elif attn_type == 'gc':
module_cls = GlobalContext
elif attn_type == 'cbam':
module_cls = CbamModule
elif attn_type == 'lcbam':
module_cls = LightCbamModule
# Attention / attention-like modules w/ significant params
# Typically replace some of the existing workhorse convs in a network architecture.
# All of these accept a stride argument and can spatially downsample the input.
elif attn_type == 'sk':
module_cls = SelectiveKernel
elif attn_type == 'splat':
module_cls = SplitAttn
# Self-attention / attention-like modules w/ significant compute and/or params
# Typically replace some of the existing workhorse convs in a network architecture.
# All of these accept a stride argument and can spatially downsample the input.
elif attn_type == 'lambda':
return LambdaLayer
elif attn_type == 'bottleneck':
return BottleneckAttn
elif attn_type == 'halo':
return HaloAttn
elif attn_type == 'swin':
return WindowAttention
elif attn_type == 'involution':
return Involution
elif attn_type == 'nl':
module_cls = NonLocalAttn
elif attn_type == 'bat':
module_cls = BatNonLocalAttn
# Woops!
else:
assert False, "Invalid attn module (%s)" % attn_type
elif isinstance(attn_type, bool):

@ -1,25 +0,0 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention
def get_self_attn(attn_type):
    """Return the self-attention module class registered under ``attn_type``.

    Args:
        attn_type: one of 'bottleneck', 'halo', 'lambda', 'swin' or 'involution'.

    Returns:
        The matching module class (the class itself, not an instance).

    Raises:
        AssertionError: if ``attn_type`` is not one of the names listed above.
    """
    if attn_type == 'bottleneck':
        return BottleneckAttn
    if attn_type == 'halo':
        return HaloAttn
    if attn_type == 'lambda':
        return LambdaLayer
    if attn_type == 'swin':
        return WindowAttention
    if attn_type == 'involution':
        return Involution
    # Raised explicitly instead of `assert False, ...`: assert statements are removed
    # under `python -O`, which would make an unknown type silently return None.
    raise AssertionError(f"Unknown attn type ({attn_type})")
def create_self_attn(attn_type, dim, stride=1, **kwargs):
    """Look up the self-attention class for ``attn_type`` and instantiate it.

    ``dim`` is passed positionally, ``stride`` and any extra ``kwargs`` are
    forwarded as keyword arguments to the class returned by ``get_self_attn``.
    """
    return get_self_attn(attn_type)(dim, stride=stride, **kwargs)

@ -38,6 +38,10 @@ from torch import nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class EcaModule(nn.Module):
"""Constructs an ECA module.
@ -48,23 +52,48 @@ class EcaModule(nn.Module):
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
kernel_size: Adaptive selection of kernel size (default=3)
gamma: used in kernel_size calc, see above
beta: used in kernel_size calc, see above
act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
gate_layer: gating non-linearity to use
"""
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
def __init__(
self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
super(EcaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
assert kernel_size % 2 == 1
padding = (kernel_size - 1) // 2
if use_mlp:
# NOTE 'mlp' mode is a timm experiment, not in paper
assert channels is not None
if rd_channels is None:
rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
act_layer = act_layer or nn.ReLU
self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
self.act = create_act_layer(act_layer)
self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
else:
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
self.act = None
self.conv2 = None
self.gate = create_act_layer(gate_layer)
def forward(self, x):
y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv
y = self.conv(y)
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
if self.conv2 is not None:
y = self.act(y)
y = self.conv2(y)
y = self.gate(y).view(x.shape[0], -1, 1, 1)
return x * y.expand_as(x)
EfficientChannelAttn = EcaModule # alias
class CecaModule(nn.Module):
"""Constructs a circular ECA module.
@ -83,25 +112,34 @@ class CecaModule(nn.Module):
refer to original paper https://arxiv.org/pdf/1910.03151.pdf
(default=None. if channel size not given, use k_size given for kernel size.)
kernel_size: Adaptive selection of kernel size (default=3)
gamma: used in kernel_size calc, see above
beta: used in kernel_size calc, see above
act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
gate_layer: gating non-linearity to use
"""
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
super(CecaModule, self).__init__()
assert kernel_size % 2 == 1
if channels is not None:
t = int(abs(math.log(channels, 2) + beta) / gamma)
kernel_size = max(t if t % 2 else t + 1, 3)
has_act = act_layer is not None
assert kernel_size % 2 == 1
# PyTorch circular padding mode is buggy as of pytorch 1.4
# see https://github.com/pytorch/pytorch/pull/17240
# implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
self.padding = (kernel_size - 1) // 2
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
y = x.mean((2, 3)).view(x.shape[0], 1, -1)
# Manually implement circular padding, F.pad does not seem to be bugged
y = F.pad(y, (self.padding, self.padding), mode='circular')
y = self.conv(y)
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
y = self.gate(y).view(x.shape[0], -1, 1, 1)
return x * y.expand_as(x)
CircularEfficientChannelAttn = CecaModule

@ -0,0 +1,90 @@
""" Gather-Excite Attention Block
Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen another
impl that covers all of the cases.
NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlp
class GatherExcite(nn.Module):
    """ Gather-Excite Attention Module

    Gathers a (possibly spatially reduced) context tensor from the input, optionally passes
    it through a 1x1-conv MLP, resizes it back to the input resolution if needed, and uses
    the gated result to scale the input.

    Args:
        channels: number of input (and output) channels.
        feat_size: kernel size of the single depthwise gather conv used when
            ``extra_params=True`` and ``extent == 0``; must be given in that case.
        extra_params: if True, gather with learned depthwise convs; otherwise gather with pooling.
        extent: 0 selects the global extent. Otherwise it must be even; the pooling path uses
            kernel ``2 * extent - 1`` with stride ``extent``, and the conv path stacks
            ``int(log2(extent))`` stride-2 3x3 depthwise convs.
        use_mlp: if True, apply a ``ConvMlp`` with ``rd_channels`` hidden channels to the
            gathered tensor; otherwise leave it unchanged.
        rd_ratio: hidden-channel ratio, used only when ``rd_channels`` is not given.
        rd_channels: explicit number of hidden channels for the MLP.
        rd_divisor: divisor passed to ``make_divisible`` when deriving ``rd_channels``.
        add_maxpool: (pooling path only) blend a max pool 50/50 with the average pool.
        act_layer: activation used between gather convs and inside the MLP.
        norm_layer: normalization layer constructor applied after each gather conv;
            a falsy value disables normalization.
        gate_layer: gating activation applied before multiplying with the input.
    """
    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                self.gather.add_module(
                    'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                if norm_layer:
                    # honour the caller's norm_layer instead of always building nn.BatchNorm2d
                    self.gather.add_module('norm1', norm_layer(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f'conv{i + 1}',
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm{i + 1}', norm_layer(channels))
                    if i != num_conv - 1:
                        # no activation after the final gather conv
                        self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                # global pooling, no kernel / stride needed
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdim=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            # gathered map is not 1x1 so it cannot broadcast; resize back to the input's spatial size
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)

@ -0,0 +1,67 @@
""" Global Context Attention Block
Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
- https://arxiv.org/abs/1904.11492
Official code consulted as reference: https://github.com/xvjiarui/GCNet
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
from .mlp import ConvMlp
from .norm import LayerNorm2d
class GlobalContext(nn.Module):
    """ Global Context attention block

    Pools the input into a per-channel context of shape (B, C, 1, 1), either with a learned
    softmax attention map over spatial positions (``use_attn=True``) or a plain spatial mean.
    The context is then passed through up to two ``ConvMlp`` branches: one whose gated output
    scales the input (``fuse_scale``) and one whose output is added to the input (``fuse_add``).

    NOTE: ``init_last_zero`` is stored on the module, but ``reset_parameters`` does not read it;
    the last conv weight of the add branch is zero-initialized whenever that branch exists.
    """

    def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False,
                 rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super().__init__()
        act_layer = get_act_layer(act_layer)

        # 1x1 conv producing one attention logit per spatial position
        if use_attn:
            self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True)
        else:
            self.conv_attn = None

        if rd_channels is None:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)

        self.mlp_add = ConvMlp(
            channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) if fuse_add else None
        self.mlp_scale = ConvMlp(
            channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) if fuse_scale else None

        self.gate = create_act_layer(gate_layer)
        self.init_last_zero = init_last_zero
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the attention conv weight and zero the add-branch output conv weight."""
        attn_conv = self.conv_attn
        if attn_conv is not None:
            nn.init.kaiming_normal_(attn_conv.weight, mode='fan_in', nonlinearity='relu')
        add_branch = self.mlp_add
        if add_branch is not None:
            nn.init.zeros_(add_branch.fc2.weight)

    def forward(self, x):
        batch, chs, height, width = x.shape

        if self.conv_attn is None:
            # no attention: the context is the spatial mean
            context = x.mean(dim=(2, 3), keepdim=True)
        else:
            logits = self.conv_attn(x).reshape(batch, 1, height * width)  # (B, 1, H * W)
            weights = F.softmax(logits, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
            # attention-weighted sum over spatial positions
            context = x.reshape(batch, chs, height * width).unsqueeze(1) @ weights
            context = context.view(batch, chs, 1, 1)

        if self.mlp_scale is not None:
            x = x * self.gate(self.mlp_scale(context))
        if self.mlp_add is not None:
            x = x + self.mlp_add(context)
        return x

@ -28,4 +28,4 @@ def make_divisible(v, divisor=8, min_value=None, round_limit=.9):
# Make sure that round down does not go down by more than 10%.
if new_v < round_limit * v:
new_v += divisor
return new_v
return new_v

@ -16,7 +16,7 @@ class Involution(nn.Module):
kernel_size=3,
stride=1,
group_size=16,
reduction_ratio=4,
rd_ratio=4,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
@ -28,12 +28,12 @@ class Involution(nn.Module):
self.groups = self.channels // self.group_size
self.conv1 = ConvBnAct(
in_channels=channels,
out_channels=channels // reduction_ratio,
out_channels=channels // rd_ratio,
kernel_size=1,
norm_layer=norm_layer,
act_layer=act_layer)
self.conv2 = self.conv = create_conv2d(
in_channels=channels // reduction_ratio,
in_channels=channels // rd_ratio,
out_channels=kernel_size**2 * self.groups,
kernel_size=1,
stride=1)

@ -77,3 +77,26 @@ class GatedMlp(nn.Module):
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMlp(nn.Module):
    """ Two-layer MLP built from 1x1 convolutions, so spatial dims are kept as-is

    Pipeline: fc1 (1x1 conv) -> norm -> act -> dropout -> fc2 (1x1 conv).
    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when falsy,
    and ``norm`` is an identity when no ``norm_layer`` is given.
    """

    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        if norm_layer:
            self.norm = norm_layer(hidden_features)
        else:
            self.norm = nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.act(self.norm(self.fc1(x)))
        return self.fc2(self.drop(hidden))

@ -0,0 +1,145 @@
""" Bilinear-Attention-Transform and Non-Local Attention
Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
- https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
import torch
from torch import nn
from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
class NonLocalAttn(nn.Module):
    """Spatial non-local (NL) block for image classification.

    Adapted from https://github.com/BA-Transform/BAT-Image-Classification
    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.

    Three 1x1 convs (``t``, ``p``, ``g``) project the input to ``rd_channels``; a softmax over
    the pairwise position similarities of ``t`` and ``p`` mixes ``g`` across positions, and the
    result is projected back (``z``), batch-normalized and added to the input. The norm's weight
    and bias are initialized to zero, so a freshly built block returns its input unchanged.
    Extra ``kwargs`` are accepted and ignored.
    """

    def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.scale = in_channels ** -0.5 if use_scale else 1.0
        conv_args = dict(kernel_size=1, stride=1, bias=True)
        self.t = nn.Conv2d(in_channels, rd_channels, **conv_args)
        self.p = nn.Conv2d(in_channels, rd_channels, **conv_args)
        self.g = nn.Conv2d(in_channels, rd_channels, **conv_args)
        self.z = nn.Conv2d(rd_channels, in_channels, **conv_args)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        theta, phi, gee = self.t(x), self.p(x), self.g(x)

        B, C, H, W = theta.size()
        theta = theta.flatten(2).transpose(1, 2)  # (B, H * W, C)
        phi = phi.flatten(2)  # (B, C, H * W)
        gee = gee.flatten(2).transpose(1, 2)  # (B, H * W, C)

        # similarity of every position with every other position, softmax over the last dim
        weights = F.softmax(torch.bmm(theta, phi) * self.scale, dim=2)

        mixed = torch.bmm(weights, gee).transpose(1, 2).reshape(B, C, H, W)
        return self.norm(self.z(mixed)) + x

    def reset_parameters(self):
        """Kaiming-init every conv (zero bias) and zero the weight and bias of every norm layer."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)
class BilinearAttnTransform(nn.Module):
    """Bilinear attentional transform: computes ``P @ x @ Q`` per channel.

    ``P`` and ``Q`` are predicted from the input: ``conv1`` reduces it to ``groups`` channels,
    the result is max-pooled into ``block_size`` rows / columns, and ``conv_p`` / ``conv_q`` turn
    those into one (block_size x block_size) matrix per group. After a sigmoid and a
    normalization, each matrix is shared by the ``C // groups`` channels of its group and
    expanded to the full spatial size with ``resize_mat``.

    Args:
        in_channels: number of input (and output) channels.
        block_size: side length of the predicted attention matrices; the input height and
            width must both be divisible by it.
        groups: number of channel groups that share one pair of attention matrices.
        act_layer: activation passed to the ``ConvBnAct`` layers.
        norm_layer: normalization passed to the ``ConvBnAct`` layers.
    """

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(BilinearAttnTransform, self).__init__()

        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t):
        """Expand each square matrix in ``x`` by a factor ``t``.

        Every entry of the (block_size x block_size) matrices becomes that entry times a
        ``t x t`` identity, i.e. the Kronecker product of each matrix with ``eye(t)``.
        ``x`` is returned untouched when ``t <= 1``.
        """
        B, C, block_size, block_size1 = x.shape
        assert block_size == block_size1
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
        x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
        x = x.view(B * C, block_size, block_size, t, t)
        # stitch the t x t tiles together, first along rows, then along columns
        x = torch.cat(torch.split(x, 1, dim=1), dim=3)
        x = torch.cat(torch.split(x, 1, dim=2), dim=4)
        x = x.view(B, C, block_size * t, block_size * t)
        return x

    def forward(self, x):
        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0, \
            'input height and width must be divisible by block_size'
        B, C, H, W = x.shape
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
        cp = F.adaptive_max_pool2d(out, (1, self.block_size))
        p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
        q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
        # torch.sigmoid replaces the deprecated nn.functional.sigmoid
        p = torch.sigmoid(p)
        q = torch.sigmoid(q)
        # normalize p over its last dim and q over its second-to-last dim
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)
        # share each group's matrix with all C // groups channels of that group
        p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(
            B, self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        p = p.view(B, C, self.block_size, self.block_size)
        q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(
            B, self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        q = q.view(B, C, self.block_size, self.block_size)
        p = self.resize_mat(p, H // self.block_size)
        q = self.resize_mat(q, W // self.block_size)
        y = p.matmul(x)
        y = y.matmul(q)

        y = self.conv2(y)
        return y
class BatNonLocalAttn(nn.Module):
    """ BAT (bilinear attentional transform) non-local block

    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification

    Reduces the input to ``rd_channels`` with a 1x1 ``ConvBnAct``, applies a
    ``BilinearAttnTransform``, projects back to ``in_channels``, applies 2d dropout and adds
    the result to the input. Unknown keyword arguments are accepted and ignored.
    """

    def __init__(
            self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
            drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        # the same activation / norm choice is shared by all three sub-layers
        layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
        self.conv1 = ConvBnAct(in_channels, rd_channels, 1, **layer_args)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, **layer_args)
        self.conv2 = ConvBnAct(rd_channels, in_channels, 1, **layer_args)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        attended = self.conv2(self.ba(self.conv1(x)))
        return self.dropout(attended) + x

@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm):
def forward(self, x):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm2d(nn.LayerNorm):
    """ Layernorm for channels of '2d' spatial BCHW tensors

    The affine parameters keep the shape ``[num_channels, 1, 1]``. Inputs with a 1x1 spatial
    size are normalized directly over ``[C, 1, 1]``; inputs with a larger spatial size are
    normalized over the channel dim independently at every spatial position.
    """
    def __init__(self, num_channels):
        super().__init__([num_channels, 1, 1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[-1] == 1 and x.shape[-2] == 1:
            # trailing dims already match normalized_shape ([C, 1, 1])
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # F.layer_norm normalizes over trailing dims only, so move channels last,
        # normalize over C with the flattened affine params, then restore BCHW order.
        x = x.permute(0, 2, 3, 1)
        x = F.layer_norm(x, self.normalized_shape[:1], self.weight.flatten(), self.bias.flatten(), self.eps)
        return x.permute(0, 3, 1, 2)

@ -1,50 +0,0 @@
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
    """ Squeeze-and-Excitation module as in the original SE-Nets, with a few extras

    The number of reduction (bottleneck) channels is resolved in this order:
        1. ``reduction_channels`` if given: used as-is, no rounding or minimum applied
        2. ``channels * reduction_ratio`` if ``reduction_ratio`` is given
        3. ``channels // reduction`` otherwise
    Cases 2 and 3 are passed through ``make_divisible`` with ``divisor`` and ``min_channels``.
    """

    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
                 reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
        super().__init__()
        if reduction_channels is None:
            if reduction_ratio is not None:
                target_chs = channels * reduction_ratio
            else:
                target_chs = channels // reduction
            reduction_channels = make_divisible(target_chs, divisor, min_channels)
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # squeeze: mean over the spatial dims, then excite through the two 1x1 convs
        squeezed = x.mean((2, 3), keepdim=True)
        excite = self.fc2(self.act(self.fc1(squeezed)))
        return x * self.gate(excite)
class EffectiveSEModule(nn.Module):
    """ 'Effective Squeeze-Excitation'
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667

    A single 1x1 conv (channels -> channels) on the pooled features, followed by a gate.
    """
    def __init__(self, channels, gate_layer='hard_sigmoid'):
        super().__init__()
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # mean over dims 2 and 3, kept as size-1 dims so the gate broadcasts against x
        pooled = x.mean((2, 3), keepdim=True)
        return x * self.gate(self.fc(pooled))

@ -8,6 +8,7 @@ import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .helpers import make_divisible
def _kernel_valid(k):
@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module):
return x
class SelectiveKernelConv(nn.Module):
class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
""" Selective Kernel Convolution Module
@ -66,8 +67,7 @@ class SelectiveKernelConv(nn.Module):
stride (int): stride for convolutions
dilation (int): dilation for module as a whole, impacts dilation of each branch
groups (int): number of groups for each branch
attn_reduction (int, float): reduction factor for attention features
min_attn_channels (int): minimum attention feature channels
rd_ratio (int, float): reduction factor for attention features
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
@ -75,7 +75,8 @@ class SelectiveKernelConv(nn.Module):
act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use
"""
super(SelectiveKernelConv, self).__init__()
super(SelectiveKernel, self).__init__()
out_channels = out_channels or in_channels
kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
_kernel_valid(kernel_size)
if not isinstance(kernel_size, list):
@ -101,7 +102,7 @@ class SelectiveKernelConv(nn.Module):
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block

@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import nn
from .helpers import make_divisible
class RadixSoftmax(nn.Module):
def __init__(self, radix, cardinality):
@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module):
return x
class SplitAttnConv2d(nn.Module):
"""Split-Attention Conv2d
class SplitAttn(nn.Module):
"""Split-Attention (aka Splat)
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
super(SplitAttnConv2d, self).__init__()
super(SplitAttn, self).__init__()
out_channels = out_channels or in_channels
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
attn_chs = max(in_channels * radix // reduction_factor, 32)
if rd_channels is None:
attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
else:
attn_chs = rd_channels * radix
padding = kernel_size // 2 if padding is None else padding
self.conv = nn.Conv2d(
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
self.act1 = act_layer(inplace=True)
self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
self.rsoftmax = RadixSoftmax(radix, groups)
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.fc1.out_channels
def forward(self, x):
x = self.conv(x)
if self.bn0 is not None:
x = self.bn0(x)
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act0(x)
@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module):
x_gap = x.sum(dim=1)
else:
x_gap = x
x_gap = F.adaptive_avg_pool2d(x_gap, 1)
x_gap = x_gap.mean((2, 3), keepdim=True)
x_gap = self.fc1(x_gap)
if self.bn1 is not None:
x_gap = self.bn1(x_gap)
x_gap = self.bn1(x_gap)
x_gap = self.act1(x_gap)
x_attn = self.fc2(x_gap)

@ -0,0 +1,74 @@
""" Squeeze-and-Excitation Channel Attention
An SE implementation originally based on PyTorch SE-Net impl.
Has since evolved with additional functionality / configuration.
Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
Also included is Effective Squeeze-Excitation (ESE).
Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
    """ Squeeze-and-Excitation channel gate, following the original SE-Nets with extras

    Extras:
        * `rd_channels` sets the reduced channel count directly when given (truthy)
        * otherwise `channels * rd_ratio` (default: 1/16) is passed through make_divisible
          with `rd_divisor` (default: 8)
        * `add_maxpool` blends a global max with the global mean in the squeeze step
        * activation, optional norm layer, and gate layer are all configurable
    """
    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super().__init__()
        self.add_maxpool = add_maxpool
        rd_channels = rd_channels or make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True)
        if norm_layer:
            self.bn = norm_layer(rd_channels)
        else:
            # no norm requested, keep the attribute so forward() has a single code path
            self.bn = nn.Identity()
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        squeezed = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            squeezed = 0.5 * squeezed + 0.5 * x.amax((2, 3), keepdim=True)
        hidden = self.act(self.bn(self.fc1(squeezed)))
        return x * self.gate(self.fc2(hidden))
SqueezeExcite = SEModule # alias
class EffectiveSEModule(nn.Module):
    """ 'Effective Squeeze-Excitation'
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667

    A single 1x1 conv (channels -> channels) on the pooled features, followed by a gate.
    Extra keyword arguments are accepted and ignored.
    """
    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
        super().__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        squeezed = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            squeezed = 0.5 * squeezed + 0.5 * x.amax((2, 3), keepdim=True)
        return x * self.gate(self.fc(squeezed))
EffectiveSqueezeExcite = EffectiveSEModule # alias

@ -72,6 +72,10 @@ default_cfgs = {
'tf_mobilenetv3_small_minimal_100': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth',
mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD),
'fbnetv3_b': _cfg(),
'fbnetv3_d': _cfg(),
'fbnetv3_g': _cfg(),
}
@ -86,7 +90,7 @@ class MobileNetV3(nn.Module):
"""
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True,
pad_type='', act_layer=None, norm_layer=None, se_layer=None,
pad_type='', act_layer=None, norm_layer=None, se_layer=None, se_from_exp=True,
round_chs_fn=round_channels, drop_rate=0., drop_path_rate=0., global_pool='avg'):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
@ -104,7 +108,7 @@ class MobileNetV3(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn,
output_stride=32, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer, drop_path_rate=drop_path_rate)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
self.feature_info = builder.features
@ -161,8 +165,8 @@ class MobileNetV3Features(nn.Module):
and object detection models.
"""
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck',
in_chans=3, stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels,
def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bottleneck', in_chans=3,
stem_size=16, output_stride=32, pad_type='', round_chs_fn=round_channels, se_from_exp=True,
act_layer=None, norm_layer=None, se_layer=None, drop_rate=0., drop_path_rate=0.):
super(MobileNetV3Features, self).__init__()
act_layer = act_layer or nn.ReLU
@ -178,7 +182,7 @@ class MobileNetV3Features(nn.Module):
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn,
output_stride=output_stride, pad_type=pad_type, round_chs_fn=round_chs_fn, se_from_exp=se_from_exp,
act_layer=act_layer, norm_layer=norm_layer, se_layer=se_layer,
drop_path_rate=drop_path_rate, feature_location=feature_location)
self.blocks = nn.Sequential(*builder(stem_size, block_args))
@ -262,7 +266,7 @@ def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kw
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=resolve_act_layer(kwargs, 'hard_swish'),
se_layer=partial(SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), reduce_from_block=False),
se_layer=partial(SqueezeExcite, gate_layer='hard_sigmoid'),
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
@ -350,8 +354,7 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
# stage 6, 7x7 in
['cn_r1_k1_s1_c960'], # hard-swish
]
se_layer = partial(
SqueezeExcite, gate_fn=get_act_fn('hard_sigmoid'), force_act_layer=nn.ReLU, reduce_from_block=False, divisor=8)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=num_features,
@ -366,6 +369,67 @@ def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwarg
return model
def _gen_fbnetv3(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
""" FBNetV3
Paper: `FBNetV3: Joint Architecture-Recipe Search using Predictor Pretraining`
- https://arxiv.org/abs/2006.02049
FIXME untested, this is a preliminary impl of some FBNet-V3 variants.
"""
vl = variant.split('_')[-1]
if vl in ('a', 'b'):
stem_size = 16
arch_def = [
['ds_r2_k3_s1_e1_c16'],
['ir_r1_k5_s2_e4_c24', 'ir_r3_k5_s1_e2_c24'],
['ir_r1_k5_s2_e5_c40_se0.25', 'ir_r4_k5_s1_e3_c40_se0.25'],
['ir_r1_k5_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
['ir_r1_k3_s1_e5_c120_se0.25', 'ir_r5_k5_s1_e3_c120_se0.25'],
['ir_r1_k3_s2_e6_c184_se0.25', 'ir_r5_k5_s1_e4_c184_se0.25', 'ir_r1_k5_s1_e6_c224_se0.25'],
['cn_r1_k1_s1_c1344'],
]
elif vl == 'd':
stem_size = 24
arch_def = [
['ds_r2_k3_s1_e1_c16'],
['ir_r1_k3_s2_e5_c24', 'ir_r5_k3_s1_e2_c24'],
['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r4_k3_s1_e3_c40_se0.25'],
['ir_r1_k3_s2_e5_c72', 'ir_r4_k3_s1_e3_c72'],
['ir_r1_k3_s1_e5_c128_se0.25', 'ir_r6_k5_s1_e3_c128_se0.25'],
['ir_r1_k3_s2_e6_c208_se0.25', 'ir_r5_k5_s1_e5_c208_se0.25', 'ir_r1_k5_s1_e6_c240_se0.25'],
['cn_r1_k1_s1_c1440'],
]
elif vl == 'g':
stem_size = 32
arch_def = [
['ds_r3_k3_s1_e1_c24'],
['ir_r1_k5_s2_e4_c40', 'ir_r4_k5_s1_e2_c40'],
['ir_r1_k5_s2_e4_c56_se0.25', 'ir_r4_k5_s1_e3_c56_se0.25'],
['ir_r1_k5_s2_e5_c104', 'ir_r4_k3_s1_e3_c104'],
['ir_r1_k3_s1_e5_c160_se0.25', 'ir_r8_k5_s1_e3_c160_se0.25'],
['ir_r1_k3_s2_e6_c264_se0.25', 'ir_r6_k5_s1_e5_c264_se0.25', 'ir_r2_k5_s1_e6_c288_se0.25'],
['cn_r1_k1_s1_c1728'],
]
else:
raise NotImplemented
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.95)
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', rd_round_fn=round_chs_fn)
act_layer = resolve_act_layer(kwargs, 'hard_swish')
model_kwargs = dict(
block_args=decode_arch_def(arch_def),
num_features=1984,
head_bias=False,
stem_size=stem_size,
round_chs_fn=round_chs_fn,
se_from_exp=False,
norm_layer=partial(nn.BatchNorm2d, **resolve_bn_args(kwargs)),
act_layer=act_layer,
se_layer=se_layer,
**kwargs,
)
model = _create_mnv3(variant, pretrained, **model_kwargs)
return model
@register_model
def mobilenetv3_large_075(pretrained=False, **kwargs):
""" MobileNet V3 """
@ -474,3 +538,24 @@ def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs):
kwargs['pad_type'] = 'same'
model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def fbnetv3_b(pretrained=False, **kwargs):
    """ FBNetV3-B, built by `_gen_fbnetv3` """
    return _gen_fbnetv3('fbnetv3_b', pretrained=pretrained, **kwargs)
@register_model
def fbnetv3_d(pretrained=False, **kwargs):
    """ FBNetV3-D, built by `_gen_fbnetv3` """
    return _gen_fbnetv3('fbnetv3_d', pretrained=pretrained, **kwargs)
@register_model
def fbnetv3_g(pretrained=False, **kwargs):
    """ FBNetV3-G, built by `_gen_fbnetv3` """
    return _gen_fbnetv3('fbnetv3_g', pretrained=pretrained, **kwargs)

@ -182,7 +182,7 @@ def _nfres_cfg(
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
@ -193,7 +193,7 @@ def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
@ -202,11 +202,10 @@ def _nfnet_cfg(
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
return cfg
@ -243,7 +242,7 @@ model_cfgs = dict(
# Experimental 'light' versions of NFNet-F that are little leaner
nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
eca_nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
@ -272,9 +271,9 @@ model_cfgs = dict(
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),

@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
groups=groups, **cargs)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
else:
self.se = None
cargs['act_layer'] = None

@ -11,7 +11,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SplitAttnConv2d
from .layers import SplitAttn
from .registry import register_model
from .resnet import ResNet
@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module):
self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None
if self.radix >= 1:
self.conv2 = SplitAttnConv2d(
self.conv2 = SplitAttn(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness
self.act2 = None
self.bn2 = nn.Identity()
self.act2 = nn.Identity()
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module):
out = self.avd_first(out)
out = self.conv2(out)
if self.bn2 is not None:
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act2(out)
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act2(out)
if self.avd_last is not None:
out = self.avd_last(out)

@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)

@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman
"""
import torch.nn as nn
from functools import partial
from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -48,26 +49,7 @@ default_cfgs = dict(
url=''),
)
class SEWithNorm(nn.Module):
    """ Squeeze-and-Excitation gate with a BatchNorm2d between the first 1x1 conv and its activation

    The reduced channel count is `reduction_channels` when given (truthy), otherwise
    `int(channels * se_ratio)` passed through make_divisible with `divisor`.
    """
    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                 gate_layer='sigmoid'):
        super().__init__()
        if not reduction_channels:
            reduction_channels = make_divisible(int(channels * se_ratio), divisor=divisor)
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.bn = nn.BatchNorm2d(reduction_channels)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # mean over dims 2 and 3, kept as size-1 dims so the gate broadcasts against x
        pooled = x.mean((2, 3), keepdim=True)
        hidden = self.act(self.bn(self.fc1(pooled)))
        return x * self.gate(self.fc2(hidden))
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
class LinearBottleneck(nn.Module):
@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module):
self.conv_exp = None
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else:
self.se = None
self.act_dw = create_act_layer(dw_act_layer)
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)

@ -14,7 +14,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectiveKernelConv, ConvBnAct, create_attn
from .layers import SelectiveKernel, ConvBnAct, create_attn
from .registry import register_model
from .resnet import ResNet
@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module):
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = SelectiveKernelConv(
self.conv1 = SelectiveKernel(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module):
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv2 = SelectiveKernelConv(
self.conv2 = SelectiveKernel(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
variation splits the input channels to the selective convolutions to keep param count down.
"""
sk_kwargs = dict(
min_attn_channels=16,
attn_reduction=8,
split_input=True)
sk_kwargs = dict(rd_ratio=1 / 8, rd_divisor=16, split_input=True)
model_args = dict(
block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
zero_init_last_bn=False, **kwargs)
@ -213,8 +207,9 @@ def skresnext50_32x4d(pretrained=False, **kwargs):
"""Constructs a Select Kernel ResNeXt50-32x4d model. This should be equivalent to
the SKNet-50 model in the Select Kernel Paper
"""
sk_kwargs = dict(rd_ratio=1/16, rd_divisor=32, split_input=False)
model_args = dict(
block=SelectiveKernelBottleneck, layers=[3, 4, 6, 3], cardinality=32, base_width=4,
zero_init_last_bn=False, **kwargs)
block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
return _create_skresnet('skresnext50_32x4d', pretrained, **model_args)

@ -84,8 +84,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
reduction_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
rd_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
def forward(self, x):
if self.downsample is not None:
@ -125,7 +125,7 @@ class Bottleneck(nn.Module):
aa_layer(channels=planes, filt_size=3, stride=2))
reduction_chs = max(planes * self.expansion // 8, 64)
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .registry import register_model
@ -39,15 +39,6 @@ default_cfgs = dict(
)
class LayerNormBHWC(nn.LayerNorm):
    """ LayerNorm over dim 1 of a 4-d tensor, computed by moving that dim last and back again """
    def __init__(self, dim):
        super().__init__(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # layer_norm works on trailing dims, so move dim 1 to the end first
        channels_last = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(channels_last, self.normalized_shape, self.weight, self.bias, self.eps)
        # restore the original dim order
        return normed.permute(0, 3, 1, 2)
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
@ -119,7 +110,7 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
@ -148,7 +139,7 @@ class Block(nn.Module):
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
super().__init__()
self.num_classes = num_classes

@ -1 +1 @@
__version__ = '0.4.10'
__version__ = '0.4.11'

Loading…
Cancel
Save