Add Gather-Excite and Global Context attn modules. Refactor existing SE-like attn for consistency and refactor byob/byoanet for less redundancy.

more_attn
Ross Wightman 3 years ago
parent 9c78de8c02
commit 742c2d5247

@ -17,7 +17,6 @@ from .inception_resnet_v2 import *
from .inception_v3 import *
from .inception_v4 import *
from .levit import *
#from .levit import *
from .mlp_mixer import *
from .mobilenetv3 import *
from .nasnet import *

@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable
from functools import partial
import torch
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\
reduce_feat_size, register_block, num_groups, LayerFn, _init_weights
from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\
make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByoaNet']
__all__ = []
def _cfg(url='', **kwargs):
@ -63,100 +51,68 @@ default_cfgs = {
'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
}
@dataclass
class ByoaBlocksCfg(BlocksCfg):
    """Per-stage block config for ByoaNet; currently adds no fields on top of BlocksCfg."""
    # FIXME: allow overriding the self_attn layer or its args per block/stage
    pass
@dataclass
class ByoaCfg(ByobCfg):
    """Model config for ByoaNet: the ByobCfg fields plus self-attention settings."""
    # one entry per stage; an entry is a single block cfg or a tuple of block cfgs
    blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None
    # name handed to get_self_attn() to look up the self-attn layer; None disables it
    self_attn_layer: Optional[str] = None
    # when True, 'self_attn' blocks are given the current feature-map size (feat_size)
    self_attn_fixed_size: bool = False
    # extra keyword args bound onto the self-attn layer factory
    self_attn_kwargs: dict = field(default_factory=dict)
def interleave_attn(
        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoaBlocksCfg, ...]:
    """Build a stage of ``d`` single-block configs that mixes two block types.

    Args:
        types: Pair of block type names. ``types[0]`` is used by default and
            ``types[1]`` is used at the selected positions.
        every: Either an integer step, which is expanded to the positions
            ``range(0 if first else every, d, every)``, or an explicit list of positions.
        d: Number of block configs to generate.
        first: With an integer ``every``, start selecting at index 0 instead of index ``every``.
        **kwargs: Forwarded unchanged to every ``ByoaBlocksCfg`` that is created.

    Returns:
        Tuple of ``d`` ``ByoaBlocksCfg`` entries, each with ``d=1``.
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            # the step never lands inside the stage, so select only the last block
            every = [d - 1]
    # membership is tested once per block below, so look positions up in a set
    selected = set(every)
    return tuple(
        ByoaBlocksCfg(type=types[1] if idx in selected else types[0], d=1, **kwargs)
        for idx in range(d))
model_cfgs = dict(
botnet26t=ByoaCfg(
botnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
botnet50ts=ByoaCfg(
botnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
eca_botnext26ts=ByoaCfg(
eca_botnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='bottleneck',
self_attn_fixed_size=True,
self_attn_kwargs=dict()
),
halonet_h1=ByoaCfg(
halonet_h1=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='7x7',
@ -165,12 +121,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet_h1_c4c5=ByoaCfg(
halonet_h1_c4c5=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0),
ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0),
ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0),
),
stem_chs=64,
stem_type='tiered',
@ -179,12 +135,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=3),
),
halonet26t=ByoaCfg(
halonet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -193,12 +149,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
halonet50ts=ByoaCfg(
halonet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -208,12 +164,12 @@ model_cfgs = dict(
self_attn_layer='halo',
self_attn_kwargs=dict(block_size=8, halo_size=2)
),
eca_halonext26ts=ByoaCfg(
eca_halonext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -225,12 +181,12 @@ model_cfgs = dict(
self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res
),
lambda_resnet26t=ByoaCfg(
lambda_resnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -239,12 +195,12 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
lambda_resnet50t=ByoaCfg(
lambda_resnet50t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -253,12 +209,12 @@ model_cfgs = dict(
self_attn_layer='lambda',
self_attn_kwargs=dict()
),
eca_lambda_resnext26ts=ByoaCfg(
eca_lambda_resnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -270,77 +226,76 @@ model_cfgs = dict(
self_attn_kwargs=dict()
),
swinnet26t=ByoaCfg(
swinnet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
swinnet50ts=ByoaCfg(
swinnet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
eca_swinnext26ts=ByoaCfg(
eca_swinnext26ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25),
interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool='maxpool',
num_features=0,
fixed_input_size=True,
act_layer='silu',
attn_layer='eca',
self_attn_layer='swin',
self_attn_fixed_size=True,
self_attn_kwargs=dict(win_size=8)
),
rednet26t=ByoaCfg(
rednet26t=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered', # FIXME RedNet uses involution in middle of stem
stem_pool='maxpool',
num_features=0,
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
rednet50ts=ByoaCfg(
rednet50ts=ByoModelCfg(
blocks=(
ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25),
ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25),
),
stem_chs=64,
stem_type='tiered',
@ -348,161 +303,14 @@ model_cfgs = dict(
num_features=0,
act_layer='silu',
self_attn_layer='involution',
self_attn_fixed_size=False,
self_attn_kwargs=dict()
),
)
@dataclass
class ByoaLayerFn(LayerFn):
    """LayerFn extended with a factory for the self-attention layer."""
    # callable that builds the self-attn module; get_layer_fns() leaves it as None
    # when the model cfg does not set self_attn_layer
    self_attn: Optional[Callable] = None
class SelfAttnBlock(nn.Module):
    """ ResNet-like bottleneck block: 1x1 conv -> optional kxk conv -> self-attn -> 1x1 conv.

    The result of the conv/attention branch is added to a shortcut branch and passed
    through the final activation (skipped when `linear_out` is set).

    Args:
        in_chs: input channels.
        out_chs: output channels.
        kernel_size: kernel size of the optional kxk conv.
        stride: block stride; applied by the kxk conv when `extra_conv` is set,
            otherwise handed to the self-attn layer.
        dilation: pair of dilations; the first is used for the shortcut and the kxk conv.
        bottle_ratio: mid channels = make_divisible(out_chs * bottle_ratio).
        group_size: passed to num_groups() to derive the groups of the kxk conv.
        downsample: shortcut type passed to create_downsample().
        extra_conv: insert the kxk conv between the first 1x1 conv and the self-attn layer.
        linear_out: if True, no activation is applied after the residual add.
        post_attn_na: apply `layers.norm_act` after the self-attn layer.
        feat_size: forwarded to the self-attn layer when not None.
        layers: layer factory bundle, required.
        drop_block: forwarded to the kxk conv.
        drop_path_rate: DropPath rate on the residual branch, disabled when <= 0.
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
                 downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
                 layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.):
        super().__init__()
        assert layers is not None
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # identity shortcut only when neither shape nor dilation changes across the block
        keeps_shape = in_chs == out_chs and stride == 1 and dilation[0] == dilation[1]
        if keeps_shape:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = create_downsample(
                downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
                apply_act=False, layers=layers)

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
        if extra_conv:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
                groups=groups, drop_block=drop_block)
            attn_stride = 1  # striding already done by the kxk conv
        else:
            self.conv2_kxk = nn.Identity()
            attn_stride = stride

        attn_kwargs = dict(feat_size=feat_size) if feat_size is not None else {}
        # FIXME need to dilate self attn to have dilated network support
        self.self_attn = layers.self_attn(mid_chs, stride=attn_stride, **attn_kwargs)
        if post_attn_na:
            self.post_attn = layers.norm_act(mid_chs)
        else:
            self.post_attn = nn.Identity()
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        if drop_path_rate > 0.:
            self.drop_path = DropPath(drop_path_rate)
        else:
            self.drop_path = nn.Identity()
        if linear_out:
            self.act = nn.Identity()
        else:
            self.act = layers.act(inplace=True)

    def init_weights(self, zero_init_last_bn=False):
        """Optionally zero the last norm weight, then reset the self-attn layer if it supports it."""
        if zero_init_last_bn:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        attn = self.self_attn
        if hasattr(attn, 'reset_parameters'):
            attn.reset_parameters()

    def forward(self, x):
        residual = self.shortcut(x)

        out = self.conv1_1x1(x)
        out = self.conv2_kxk(out)
        out = self.self_attn(out)
        out = self.post_attn(out)
        out = self.conv3_1x1(out)
        out = self.drop_path(out)

        return self.act(out + residual)
register_block('self_attn', SelfAttnBlock)
def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None):
    """Add `feat_size` to the kwargs of 'self_attn' blocks when the model uses fixed-size self-attn.

    `block_kwargs` is updated in place and also returned; it is returned untouched for any
    other block type or when `model_cfg.self_attn_fixed_size` is False.
    """
    if block_cfg.type != 'self_attn' or not model_cfg.self_attn_fixed_size:
        return block_kwargs
    # fixed-size self-attn cannot be built without a known feature-map size
    assert feat_size is not None
    block_kwargs['feat_size'] = feat_size
    return block_kwargs
def get_layer_fns(cfg: ByoaCfg):
    """Resolve the layer names and kwargs in `cfg` into a ByoaLayerFn bundle of factories.

    `attn` and `self_attn` are left as None when the corresponding layer name is not set.
    """
    act = get_act_layer(cfg.act_layer)
    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)

    attn = None
    if cfg.attn_layer:
        attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs)

    self_attn = None
    if cfg.self_attn_layer:
        self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs)

    return ByoaLayerFn(
        conv_norm_act=conv_norm_act,
        norm_act=norm_act,
        act=act,
        attn=attn,
        self_attn=self_attn,
    )
class ByoaNet(nn.Module):
    """ 'Bring-your-own-attention' Net

    A ResNet inspired backbone that supports interleaving traditional residual blocks with
    'Self Attention' bottleneck blocks, which replace the bottleneck kxk conv with a
    self-attention or similar module.

    The network is built as stem -> stages -> optional final 1x1 conv -> classifier head, with
    `feature_info` collecting the feature descriptors reported by the stem and stage builders.

    FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but
    torchscript limitations prevent sensible inheritance overrides.
    """

    def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg',
                 zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.drop_rate = drop_rate
        layers = get_layer_fns(cfg)
        feat_size = None
        if img_size is not None:
            feat_size = to_2tuple(img_size)

        # stem: width falls back to the first stage's channels when cfg.stem_chs is unset
        self.feature_info = []
        base_stem_chs = cfg.stem_chs or cfg.blocks[0].c
        stem_chs = int(round(base_stem_chs * cfg.width_factor))
        self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
        last_stem_feat = stem_feat[-1]
        self.feature_info.extend(stem_feat[:-1])
        feat_size = reduce_feat_size(feat_size, stride=last_stem_feat['reduction'])

        # stages: _byoa_block_args injects feat_size into fixed-size self-attn blocks
        self.stages, stage_feat = create_byob_stages(
            cfg, drop_path_rate, output_stride, last_stem_feat,
            feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args)
        last_stage_feat = stage_feat[-1]
        self.feature_info.extend(stage_feat[:-1])

        # optional final 1x1 conv, only when cfg.num_features is non-zero
        prev_chs = last_stage_feat['num_chs']
        if not cfg.num_features:
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        else:
            self.num_features = int(round(cfg.width_factor * cfg.num_features))
            self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1)
        self.feature_info.append(
            dict(num_chs=self.num_features, reduction=last_stage_feat['reduction'], module='final_conv'))

        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

        # generic init for every module first ...
        for name, module in self.named_modules():
            _init_weights(module, name)
        # ... then each block's own init_weights, which overrides the generic init above
        for module in self.modules():
            if hasattr(module, 'init_weights'):
                module.init_weights(zero_init_last_bn=zero_init_last_bn)

    def get_classifier(self):
        """Return the final classifier layer of the head."""
        return self.head.fc

    def reset_classifier(self, num_classes, global_pool='avg'):
        """Replace the head with a new ClassifierHead for `num_classes` outputs."""
        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)

    def forward_features(self, x):
        """Run stem, stages and final conv; the head is not applied."""
        feats = self.stem(x)
        feats = self.stages(feats)
        feats = self.final_conv(feats)
        return feats

    def forward(self, x):
        feats = self.forward_features(x)
        return self.head(feats)
def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs):
return build_model_with_cfg(
ByoaNet, variant, pretrained,
ByobNet, variant, pretrained,
default_cfg=default_cfgs[variant],
model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant],
feature_cfg=dict(flatten_sequential=True),

@ -26,8 +26,7 @@ Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field, replace
from collections import OrderedDict
from typing import Tuple, List, Optional, Union, Any, Callable, Sequence
from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence
from functools import partial
import torch
@ -36,10 +35,10 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible
create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple
from .registry import register_model
__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block']
__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']
def _cfg(url='', **kwargs):
@ -87,35 +86,52 @@ default_cfgs = {
'repvgg_b3g4': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth',
first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')),
# experimental configs
'resnet52qs': _cfg(first_conv='stem.conv1.conv'),
'geresnet50t': _cfg(first_conv='stem.conv1.conv'),
'gcresnet50t': _cfg(first_conv='stem.conv1.conv'),
}
@dataclass
class BlocksCfg:
class ByoBlockCfg:
type: Union[str, nn.Module]
d: int # block depth (number of block repeats in stage)
c: int # number of output channels for each block in stage
s: int = 2 # stride of stage (first block)
gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1
br: float = 1. # bottleneck-ratio of blocks in stage
no_attn: bool = False # disable channel attn (ie SE) when layer is set for model
# NOTE: these config items override the model cfgs that are applied to all blocks by default
attn_layer: Optional[str] = None
attn_kwargs: Optional[Dict[str, Any]] = None
self_attn_layer: Optional[str] = None
self_attn_kwargs: Optional[Dict[str, Any]] = None
block_kwargs: Optional[Dict[str, Any]] = None
@dataclass
class ByobCfg:
blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...]
class ByoModelCfg:
blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...]
downsample: str = 'conv1x1'
stem_type: str = '3x3'
stem_pool: str = ''
stem_pool: Optional[str] = 'maxpool'
stem_chs: int = 32
width_factor: float = 1.0
num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0
zero_init_last_bn: bool = True
fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation
act_layer: str = 'relu'
norm_layer: str = 'batchnorm'
# NOTE: these config items will be overridden by the block cfg (per-block) if they are set there
attn_layer: Optional[str] = None
attn_kwargs: dict = field(default_factory=lambda: dict())
self_attn_layer: Optional[str] = None
self_attn_kwargs: dict = field(default_factory=lambda: dict())
block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict())
def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
@ -123,103 +139,155 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0):
group_size = 0
if groups > 0:
group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0
bcfg = tuple([BlocksCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)])
return bcfg
model_cfgs = dict(
def interleave_blocks(
        types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs
) -> Tuple[ByoBlockCfg, ...]:
    """Build a stage of ``d`` single-block configs that interleaves two block types.

    Args:
        types: Pair of block type names. ``types[0]`` is used by default and
            ``types[1]`` is used at the selected positions.
        every: Either an integer step, which is expanded to the positions
            ``range(0 if first else every, d, every)``, or an explicit list of positions.
        d: Number of block configs to generate.
        first: With an integer ``every``, start selecting at index 0 instead of index ``every``.
        **kwargs: Forwarded unchanged to every ``ByoBlockCfg`` that is created.

    Returns:
        Tuple of ``d`` ``ByoBlockCfg`` entries, each with ``d=1``.
    """
    assert len(types) == 2
    if isinstance(every, int):
        every = list(range(0 if first else every, d, every))
        if not every:
            # the step never lands inside the stage, so select only the last block
            every = [d - 1]
    # membership is tested once per block below, so look positions up in a set
    selected = set(every)
    return tuple(
        ByoBlockCfg(type=types[1] if idx in selected else types[0], d=1, **kwargs)
        for idx in range(d))
gernet_l=ByobCfg(
model_cfgs = dict(
gernet_l=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_m=ByobCfg(
gernet_m=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.),
),
stem_chs=32,
stem_pool=None,
num_features=2560,
),
gernet_s=ByobCfg(
gernet_s=ByoModelCfg(
blocks=(
BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.),
ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4),
ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.),
ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.),
),
stem_chs=13,
stem_pool=None,
num_features=1920,
),
repvgg_a2=ByobCfg(
repvgg_a2=ByoModelCfg(
blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)),
stem_type='rep',
stem_chs=64,
),
repvgg_b0=ByobCfg(
repvgg_b0=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1=ByobCfg(
repvgg_b1=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b1g4=ByobCfg(
repvgg_b1g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b2=ByobCfg(
repvgg_b2=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b2g4=ByobCfg(
repvgg_b2g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
repvgg_b3=ByobCfg(
repvgg_b3=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)),
stem_type='rep',
stem_chs=64,
),
repvgg_b3g4=ByobCfg(
repvgg_b3g4=ByoModelCfg(
blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4),
stem_type='rep',
stem_chs=64,
),
resnet52q=ByobCfg(
# WARN: experimental, may vanish/change
resnet52q=ByoModelCfg(
blocks=(
BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0),
),
stem_chs=128,
stem_type='quad',
num_features=2048,
act_layer='silu',
),
# WARN: experimental, may vanish/change
geresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='ge',
attn_kwargs=dict(extent=8, extra_params=True),
#attn_kwargs=dict(extent=8),
#block_kwargs=dict(attn_last=True)
),
# WARN: experimental, may vanish/change
gcresnet50t=ByoModelCfg(
blocks=(
ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25),
ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25),
ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25),
),
stem_chs=64,
stem_type='tiered',
stem_pool=None,
attn_layer='gc'
),
)
def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]:
def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]:
if not isinstance(stage_blocks_cfg, Sequence):
stage_blocks_cfg = (stage_blocks_cfg,)
block_cfgs = []
@ -243,6 +311,7 @@ class LayerFn:
norm_act: Callable = BatchNormAct2d
act: Callable = nn.ReLU
attn: Optional[Callable] = None
self_attn: Optional[Callable] = None
class DownsampleAvg(nn.Module):
@ -275,7 +344,8 @@ class BasicBlock(nn.Module):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(BasicBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -289,15 +359,19 @@ class BasicBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -317,7 +391,8 @@ class BottleneckBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(BottleneckBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -334,14 +409,18 @@ class BottleneckBlock(nn.Module):
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv3_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -350,6 +429,7 @@ class BottleneckBlock(nn.Module):
x = self.conv2_kxk(x)
x = self.attn(x)
x = self.conv3_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
@ -368,7 +448,8 @@ class DarkBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None,
drop_path_rate=0.):
super(DarkBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -382,23 +463,28 @@ class DarkBlock(nn.Module):
self.shortcut = nn.Identity()
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_kxk.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
x = self.conv1_1x1(x)
x = self.conv2_kxk(x)
x = self.attn(x)
x = self.conv2_kxk(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -415,7 +501,8 @@ class EdgeBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None,
drop_block=None, drop_path_rate=0.):
super(EdgeBlock, self).__init__()
layers = layers or LayerFn()
mid_chs = make_divisible(out_chs * bottle_ratio)
@ -431,14 +518,18 @@ class EdgeBlock(nn.Module):
self.conv1_kxk = layers.conv_norm_act(
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
def init_weights(self, zero_init_last_bn=False):
def init_weights(self, zero_init_last_bn: bool = False):
if zero_init_last_bn:
nn.init.zeros_(self.conv2_1x1.bn.weight)
for attn in (self.attn, self.attn_last):
if hasattr(attn, 'reset_parameters'):
attn.reset_parameters()
def forward(self, x):
shortcut = self.shortcut(x)
@ -446,6 +537,7 @@ class EdgeBlock(nn.Module):
x = self.conv1_kxk(x)
x = self.attn(x)
x = self.conv2_1x1(x)
x = self.attn_last(x)
x = self.drop_path(x)
x = self.act(x + shortcut)
return x
@ -460,7 +552,7 @@ class RepVggBlock(nn.Module):
"""
def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None,
downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.):
downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
super(RepVggBlock, self).__init__()
layers = layers or LayerFn()
groups = num_groups(group_size, in_chs)
@ -475,12 +567,15 @@ class RepVggBlock(nn.Module):
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
self.act = layers.act(inplace=True)
def init_weights(self, zero_init_last_bn: bool = False):
    """ Block-specific weight init.

    Re-initialises every BatchNorm2d found under this block with normal(0.1, 0.1) weights and
    normal(0, 0.1) biases, then resets any attention module that exposes `reset_parameters`.

    Args:
        zero_init_last_bn: accepted for signature parity with the other block types; not used here.
    """
    # NOTE this init overrides that base model init with specific changes for the block type
    for m in self.modules():
        if isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight, .1, .1)
            nn.init.normal_(m.bias, 0, .1)
    # Look attributes up defensively: the constructor signature of this block has no `attn_last`
    # option, and nn.Module raises AttributeError on unknown attributes.
    for name in ('attn', 'attn_last'):
        attn = getattr(self, name, None)
        if hasattr(attn, 'reset_parameters'):
            attn.reset_parameters()
def forward(self, x):
if self.identity is None:
@ -495,12 +590,68 @@ class RepVggBlock(nn.Module):
return x
class SelfAttnBlock(nn.Module):
    """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1

    The residual branch is conv1_1x1 -> conv2_kxk (only if `extra_conv`) -> self_attn ->
    post_attn (norm + act, only if `post_attn_na`) -> conv3_1x1 -> drop_path. The result is
    summed with the shortcut and passed through the final activation (identity if `linear_out`).
    """

    def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None,
                 downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None,
                 layers: LayerFn = None, drop_block=None, drop_path_rate=0.):
        super(SelfAttnBlock, self).__init__()
        assert layers is not None
        mid_chs = make_divisible(out_chs * bottle_ratio)
        groups = num_groups(group_size, mid_chs)

        # a projection shortcut is needed whenever the residual branch changes shape or dilation
        needs_projection = in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]
        self.shortcut = create_downsample(
            downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0],
            apply_act=False, layers=layers) if needs_projection else nn.Identity()

        self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)

        attn_stride = stride
        if not extra_conv:
            self.conv2_kxk = nn.Identity()
        else:
            self.conv2_kxk = layers.conv_norm_act(
                mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
                groups=groups, drop_block=drop_block)
            attn_stride = 1  # striding done via conv if enabled

        # feat_size is only forwarded when the caller supplied one
        attn_kwargs = dict(stride=attn_stride)
        if feat_size is not None:
            attn_kwargs['feat_size'] = feat_size
        # FIXME need to dilate self attn to have dilated network support, moop moop
        self.self_attn = layers.self_attn(mid_chs, **attn_kwargs)

        self.post_attn = nn.Identity() if not post_attn_na else layers.norm_act(mid_chs)
        self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
        self.drop_path = nn.Identity() if not drop_path_rate > 0. else DropPath(drop_path_rate)
        self.act = layers.act(inplace=True) if not linear_out else nn.Identity()

    def init_weights(self, zero_init_last_bn: bool = False):
        """ Optionally zero the last BN weight, then reset the self-attn layer if it supports it. """
        if zero_init_last_bn:
            nn.init.zeros_(self.conv3_1x1.bn.weight)
        reset_fn = getattr(self.self_attn, 'reset_parameters', None)
        if hasattr(self.self_attn, 'reset_parameters'):
            reset_fn()

    def forward(self, x):
        shortcut = self.shortcut(x)

        out = self.conv1_1x1(x)
        out = self.conv2_kxk(out)
        out = self.post_attn(self.self_attn(out))
        out = self.drop_path(self.conv3_1x1(out))
        return self.act(out + shortcut)
_block_registry = dict(
basic=BasicBlock,
bottle=BottleneckBlock,
dark=DarkBlock,
edge=EdgeBlock,
rep=RepVggBlock,
self_attn=SelfAttnBlock,
)
@ -552,7 +703,7 @@ class Stem(nn.Sequential):
curr_stride *= s
prev_feat = conv_name
if 'max' in pool.lower():
if pool and 'max' in pool.lower():
self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
self.add_module('pool', nn.MaxPool2d(3, 2, 1))
curr_stride *= 2
@ -601,9 +752,58 @@ def reduce_feat_size(feat_size, stride=2):
return None if feat_size is None else tuple([s // stride for s in feat_size])
def override_kwargs(block_kwargs, model_kwargs):
    """ Pick block-level kwargs over model-level kwargs for attn / self-attn / block args.

    NOTE: the two levels are never merged. Whenever block_kwargs is anything other than None it
    completely replaces model_kwargs for that block, so an empty block_kwargs dict wipes out
    kwargs that were set at the model level.

    Always returns a dict (never None).
    """
    chosen = model_kwargs if block_kwargs is None else block_kwargs
    if not chosen:
        # covers both None and an empty dict
        return {}
    return chosen
def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ):
    """ Apply per-block attn / self-attn / extra-kwarg overrides to `block_kwargs` in place.

    Args:
        block_kwargs: constructor kwargs for the block being built; must contain a 'layers' entry.
            Its 'layers' entry is replaced and extra block kwargs are merged in.
        block_cfg: block-level config, read for attn_layer/attn_kwargs, self_attn_layer/
            self_attn_kwargs and block_kwargs.
        model_cfg: model-level config supplying the fallback values for the same fields.

    The attn (resp. self-attn) layer fn is only rebuilt when the block sets attn_layer or
    attn_kwargs (resp. the self_attn_* fields). In that case a falsy block-level layer name
    (empty string, or None with only kwargs set) disables the layer for this block.
    """
    layer_fns = block_kwargs['layers']

    # override attn layer / args with block local config
    if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None:
        # override attn layer config
        if not block_cfg.attn_layer:
            # empty string for attn_layer type will disable attn for this block
            attn_layer = None
        else:
            attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs)
            attn_layer = block_cfg.attn_layer or model_cfg.attn_layer
            # kwargs must be unpacked with ** -- a single * would bind the dict *keys* positionally
            attn_layer = partial(get_attn(attn_layer), **attn_kwargs) if attn_layer is not None else None
        layer_fns = replace(layer_fns, attn=attn_layer)

    # override self-attn layer / args with block local cfg
    if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None:
        # override self-attn layer config
        if not block_cfg.self_attn_layer:
            # empty string for self_attn_layer type will disable attn for this block
            self_attn_layer = None
        else:
            self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
            self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
            self_attn_layer = partial(get_self_attn(self_attn_layer), **self_attn_kwargs) \
                if self_attn_layer is not None else None
        layer_fns = replace(layer_fns, self_attn=self_attn_layer)

    block_kwargs['layers'] = layer_fns

    # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set
    block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs))
def create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat,
feat_size=None, layers=None, extra_args_fn=None):
cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any],
feat_size: Optional[int] = None,
layers: Optional[LayerFn] = None,
block_kwargs_fn: Optional[Callable] = update_block_kwargs):
layers = layers or LayerFn()
feature_info = []
block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks]
@ -641,8 +841,10 @@ def create_byob_stages(
drop_path_rate=dpr[stage_idx][block_idx],
layers=layers,
)
if extra_args_fn is not None:
extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size)
if block_cfg.type in ('self_attn',):
# add feat_size arg for blocks that support/need it
block_kwargs['feat_size'] = feat_size
block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg)
blocks += [create_block(block_cfg.type, **block_kwargs)]
first_dilation = dilation
prev_chs = out_chs
@ -656,12 +858,13 @@ def create_byob_stages(
return nn.Sequential(*stages), feature_info
def get_layer_fns(cfg: ByobCfg):
def get_layer_fns(cfg: ByoModelCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn)
self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
return layer_fn
@ -673,19 +876,24 @@ class ByobNet(nn.Module):
Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act).
"""
def __init__(self, cfg: ByobCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, drop_rate=0., drop_path_rate=0.):
def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
layers = get_layer_fns(cfg)
if cfg.fixed_input_size:
assert img_size is not None, 'img_size argument is required for fixed input size model'
feat_size = to_2tuple(img_size) if img_size is not None else None
self.feature_info = []
stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor))
self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers)
self.feature_info.extend(stem_feat[:-1])
feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction'])
self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers)
self.stages, stage_feat = create_byob_stages(
cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size)
self.feature_info.extend(stage_feat[:-1])
prev_chs = stage_feat[-1]['num_chs']
@ -836,3 +1044,24 @@ def repvgg_b3g4(pretrained=False, **kwargs):
`Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697
"""
return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs)
@register_model
def resnet52q(pretrained=False, **kwargs):
    """ Build the 'resnet52q' ByobNet variant.

    The variant name passed to `_create_byobnet` must match this entrypoint's own name;
    it previously pointed at 'geresnet50t' (copy-paste slip), which silently built a different model.
    """
    return _create_byobnet('resnet52q', pretrained=pretrained, **kwargs)
@register_model
def geresnet50t(pretrained=False, **kwargs):
    """ Build the 'geresnet50t' ByobNet variant via `_create_byobnet`.

    NOTE(review): presumably a ResNet-50-T with Gather-Excite attention -- confirm against the model cfg.
    """
    return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs)
@register_model
def gcresnet50t(pretrained=False, **kwargs):
    """ Build the 'gcresnet50t' ByobNet variant via `_create_byobnet`.

    NOTE(review): presumably a ResNet-50-T with Global Context attention -- confirm against the model cfg.
    """
    return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs)

@ -14,20 +14,22 @@ from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .create_self_attn import get_self_attn, create_self_attn
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .involution import Involution
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp
from .norm import GroupNorm
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .se import SEModule
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernelConv
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule

@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl.
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from torch import nn as nn
import torch.nn.functional as F
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
class ChannelAttn(nn.Module):
""" Original CBAM channel attention module, currently avg + max pool variant only.
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(ChannelAttn, self).__init__()
self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
if not rd_channels:
rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = x.mean((2, 3), keepdim=True)
x_max = F.adaptive_max_pool2d(x, 1)
x_avg = self.fc2(self.act(self.fc1(x_avg)))
x_max = self.fc2(self.act(self.fc1(x_max)))
x_attn = x_avg + x_max
return x * x_attn.sigmoid()
x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True))))
x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True))))
return x * self.gate(x_avg + x_max)
class LightChannelAttn(ChannelAttn):
"""An experimental 'lightweight' variant that sums avg + max pool first
"""
def __init__(self, channels, reduction=16):
super(LightChannelAttn, self).__init__(channels, reduction)
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(LightChannelAttn, self).__init__(
channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias)
def forward(self, x):
    """ Scale `x` by a channel gate computed from the averaged global avg + max pool.

    The avg and max pooled descriptors are blended 50/50 *before* the shared MLP
    (fc1 -> act -> fc2), so the MLP runs once. The result goes through the configurable
    `self.gate` (built from `gate_layer`), matching ChannelAttn, instead of a hard-coded sigmoid.
    """
    x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True)
    x_attn = self.fc2(self.act(self.fc1(x_pool)))
    return x * self.gate(x_attn)
class SpatialAttn(nn.Module):
""" Original CBAM spatial attention module
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(SpatialAttn, self).__init__()
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = torch.cat([x_avg, x_max], dim=1)
x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class LightSpatialAttn(nn.Module):
"""An experimental 'lightweight' variant that sums avg_pool and max_pool results.
"""
def __init__(self, kernel_size=7):
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(LightSpatialAttn, self).__init__()
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_avg = torch.mean(x, dim=1, keepdim=True)
x_max = torch.max(x, dim=1, keepdim=True)[0]
x_attn = 0.5 * x_avg + 0.5 * x_max
x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True)
x_attn = self.conv(x_attn)
return x * x_attn.sigmoid()
return x * self.gate(x_attn)
class CbamModule(nn.Module):
def __init__(self, channels, spatial_kernel_size=7):
def __init__(
self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
super(CbamModule, self).__init__()
self.channel = ChannelAttn(channels)
self.spatial = SpatialAttn(spatial_kernel_size)
self.channel = ChannelAttn(
channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
def forward(self, x):
x = self.channel(x)
@ -87,9 +96,13 @@ class CbamModule(nn.Module):
class LightCbamModule(nn.Module):
def __init__(
        self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1,
        spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False):
    """ Compose the lightweight channel and spatial attention stages.

    The reduction args (rd_ratio / rd_channels / rd_divisor), act_layer and mlp_bias configure the
    channel stage only. `gate_layer` is forwarded to BOTH stages so the two gates stay consistent,
    as CbamModule does for its non-light counterparts.
    """
    super(LightCbamModule, self).__init__()
    self.channel = LightChannelAttn(
        channels, rd_ratio=rd_ratio, rd_channels=rd_channels,
        rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias)
    self.spatial = LightSpatialAttn(spatial_kernel_size, gate_layer=gate_layer)
def forward(self, x):

@ -3,9 +3,12 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
from .se import SEModule, EffectiveSEModule
from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
from .squeeze_excite import SEModule, EffectiveSEModule
def get_attn(attn_type):
@ -23,6 +26,10 @@ def get_attn(attn_type):
module_cls = EcaModule
elif attn_type == 'ceca':
module_cls = CecaModule
elif attn_type == 'ge':
module_cls = GatherExcite
elif attn_type == 'gc':
module_cls = GlobalContext
elif attn_type == 'cbam':
module_cls = CbamModule
elif attn_type == 'lcbam':

@ -65,6 +65,9 @@ class EcaModule(nn.Module):
return x * y.expand_as(x)
EfficientChannelAttn = EcaModule # alias
class CecaModule(nn.Module):
"""Constructs a circular ECA module.
@ -105,3 +108,6 @@ class CecaModule(nn.Module):
y = self.conv(y)
y = y.view(x.shape[0], -1, 1, 1).sigmoid()
return x * y.expand_as(x)
CircularEfficientChannelAttn = CecaModule

@ -0,0 +1,90 @@
""" Gather-Excite Attention Block
Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348
Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet
I've tried to support all of the extent settings, both w/ and w/o params. I don't believe I've seen another
impl that covers all of the cases.
NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation
Hacked together by / Copyright 2021 Ross Wightman
"""
import math
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .create_conv2d import create_conv2d
from .helpers import make_divisible
from .mlp import ConvMlp
class GatherExcite(nn.Module):
    """ Gather-Excite Attention Module

    Gather step:
      * extra_params=True: a stack of depthwise convs (`self.gather`). With extent == 0 a single
        depthwise conv with kernel_size=feat_size is used (feat_size is then required); otherwise
        int(log2(extent)) stride-2 3x3 depthwise convs. Each conv is followed by `norm_layer`
        (if set) and all but the last by `act_layer`.
      * extra_params=False: parameter-free pooling. extent == 0 is a global mean; otherwise an
        average pool with kernel 2 * extent - 1 and stride extent. `add_maxpool` blends in the
        matching max pool 50/50.

    Excite step: optional ConvMlp (`use_mlp`), resize back to the input's spatial size when the
    gathered map is not 1x1, gate, and multiply with the input.

    Args:
        channels: number of input (and output) channels.
        feat_size: spatial size of the input; required for extra_params=True with extent == 0.
        extra_params: use learned depthwise convs for the gather step.
        extent: gather extent, 0 for global; must be even when non-zero.
        use_mlp: apply a ConvMlp to the gathered features.
        rd_ratio, rd_channels, rd_divisor: hidden width of the MLP. A falsy rd_channels is
            derived from channels * rd_ratio via make_divisible.
        add_maxpool: blend max pooling into the parameter-free gather (experimental).
        act_layer: activation used in the gather stack and the MLP.
        norm_layer: norm constructor applied after each gather conv; falsy disables normalisation.
        gate_layer: gating activation.
    """

    def __init__(
            self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True,
            rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'):
        super(GatherExcite, self).__init__()
        self.add_maxpool = add_maxpool
        act_layer = get_act_layer(act_layer)
        self.extent = extent
        if extra_params:
            self.gather = nn.Sequential()
            if extent == 0:
                assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params'
                self.gather.add_module(
                    'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True))
                if norm_layer:
                    # honour the configured norm layer rather than hard-coding BatchNorm2d
                    self.gather.add_module('norm1', norm_layer(channels))
            else:
                assert extent % 2 == 0
                num_conv = int(math.log2(extent))
                for i in range(num_conv):
                    self.gather.add_module(
                        f'conv{i + 1}',
                        create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True))
                    if norm_layer:
                        self.gather.add_module(f'norm{i + 1}', norm_layer(channels))
                    if i != num_conv - 1:
                        # no activation after the final gather conv
                        self.gather.add_module(f'act{i + 1}', act_layer(inplace=True))
        else:
            self.gather = None
            if self.extent == 0:
                self.gk = 0
                self.gs = 0
            else:
                assert extent % 2 == 0
                self.gk = self.extent * 2 - 1
                self.gs = self.extent

        if not rd_channels:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        size = x.shape[-2:]
        if self.gather is not None:
            x_ge = self.gather(x)
        else:
            if self.extent == 0:
                # global extent
                x_ge = x.mean(dim=(2, 3), keepdim=True)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True)
            else:
                x_ge = F.avg_pool2d(
                    x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False)
                if self.add_maxpool:
                    # experimental codepath, may remove or change
                    x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2)
        x_ge = self.mlp(x_ge)
        if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1:
            # gathered map is spatial: resize to the input resolution before gating
            x_ge = F.interpolate(x_ge, size=size)
        return x * self.gate(x_ge)

@ -0,0 +1,67 @@
""" Global Context Attention Block
Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond`
- https://arxiv.org/abs/1904.11492
Official code consulted as reference: https://github.com/xvjiarui/GCNet
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
from .mlp import ConvMlp
from .norm import LayerNorm2d
class GlobalContext(nn.Module):
    """ Global Context attention block.

    Pools the input into one context value per channel (softmax-weighted spatial pooling when
    `use_attn` is set, plain spatial mean otherwise) and fuses it back into the input by gated
    scaling (`fuse_scale`) and/or addition (`fuse_add`). Each fusion path has its own ConvMlp
    with LayerNorm2d.

    A falsy-but-not-None rd_channels (e.g. 0) is passed through as-is; only None triggers the
    rd_ratio based derivation.

    NOTE(review): `init_last_zero` is stored on the module but not read anywhere in this class;
    `reset_parameters` zeroes `mlp_add.fc2.weight` unconditionally.
    """

    def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False,
                 rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'):
        super(GlobalContext, self).__init__()
        act_layer = get_act_layer(act_layer)

        # 1x1 conv producing a single-channel spatial attention logit map
        self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None

        if rd_channels is None:
            rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        if fuse_add:
            self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_add = None
        if fuse_scale:
            self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d)
        else:
            self.mlp_scale = None

        self.gate = create_act_layer(gate_layer)
        self.init_last_zero = init_last_zero
        self.reset_parameters()

    def reset_parameters(self):
        """ (Re)initialise the attention conv and zero the last weight of the additive MLP. """
        if self.conv_attn is not None:
            nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu')
        if self.mlp_add is not None:
            # zero fc2 weight: the additive path then contributes only fc2's bias until trained
            nn.init.zeros_(self.mlp_add.fc2.weight)

    def forward(self, x):
        B, C, H, W = x.shape

        if self.conv_attn is not None:
            attn = self.conv_attn(x).reshape(B, 1, H * W)  # (B, 1, H * W)
            attn = F.softmax(attn, dim=-1).unsqueeze(3)  # (B, 1, H * W, 1)
            # (B, 1, C, H * W) @ (B, 1, H * W, 1) -> (B, 1, C, 1): attention-weighted sum over positions
            context = x.reshape(B, C, H * W).unsqueeze(1) @ attn
            context = context.view(B, C, 1, 1)
        else:
            context = x.mean(dim=(2, 3), keepdim=True)

        # scale fusion is applied first, then the additive fusion on the (possibly scaled) input
        if self.mlp_scale is not None:
            mlp_x = self.mlp_scale(context)
            x = x * self.gate(mlp_x)
        if self.mlp_add is not None:
            mlp_x = self.mlp_add(context)
            x = x + mlp_x

        return x

@ -16,7 +16,7 @@ class Involution(nn.Module):
kernel_size=3,
stride=1,
group_size=16,
reduction_ratio=4,
rd_ratio=4,
norm_layer=nn.BatchNorm2d,
act_layer=nn.ReLU,
):
@ -28,12 +28,12 @@ class Involution(nn.Module):
self.groups = self.channels // self.group_size
self.conv1 = ConvBnAct(
in_channels=channels,
out_channels=channels // reduction_ratio,
out_channels=channels // rd_ratio,
kernel_size=1,
norm_layer=norm_layer,
act_layer=act_layer)
self.conv2 = self.conv = create_conv2d(
in_channels=channels // reduction_ratio,
in_channels=channels // rd_ratio,
out_channels=kernel_size**2 * self.groups,
kernel_size=1,
stride=1)

@ -77,3 +77,26 @@ class GatedMlp(nn.Module):
x = self.fc2(x)
x = self.drop(x)
return x
class ConvMlp(nn.Module):
    """ Two-layer MLP built from 1x1 convs, so the spatial dims are kept.

    Pipeline: fc1 -> norm -> act -> dropout -> fc2. Falsy `hidden_features` / `out_features`
    fall back to `in_features`; a falsy `norm_layer` means no normalisation.
    """

    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        hidden_dim = hidden_features if hidden_features else in_features
        out_dim = out_features if out_features else in_features

        self.fc1 = nn.Conv2d(in_features, hidden_dim, kernel_size=1, bias=True)
        self.norm = nn.Identity() if not norm_layer else norm_layer(hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_dim, out_dim, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # dropout sits between the activation and the output projection
        for stage in (self.fc1, self.norm, self.act, self.drop, self.fc2):
            x = stage(x)
        return x

@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm):
def forward(self, x):
return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
class LayerNorm2d(nn.LayerNorm):
    """ Layernorm for channels of '2d' spatial BCHW tensors

    NOTE: normalized_shape is fixed to [num_channels, 1, 1], so as written the input's trailing
    dims must be exactly (C, 1, 1), i.e. a spatially pooled (B, C, 1, 1) tensor.
    """

    def __init__(self, num_channels):
        # weight / bias get shape (num_channels, 1, 1)
        super().__init__([num_channels, 1, 1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

@ -1,50 +0,0 @@
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* min_channels can be specified to keep reduced channel count at a minimum (default: 8)
* divisor can be specified to keep channels rounded to specified values (default: 1)
* reduction channels can be specified directly by arg (if reduction_channels is set)
* reduction channels can be specified by float ratio (if reduction_ratio is set)
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
super(SEModule, self).__init__()
if reduction_channels is not None:
reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done
elif reduction_ratio is not None:
reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
else:
reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc1(x_se)
x_se = self.act(x_se)
x_se = self.fc2(x_se)
return x * self.gate(x_se)
class EffectiveSEModule(nn.Module):
""" 'Effective Squeeze-Excitation
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
"""
def __init__(self, channels, gate_layer='hard_sigmoid'):
super(EffectiveSEModule, self).__init__()
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
x_se = x.mean((2, 3), keepdim=True)
x_se = self.fc(x_se)
return x * self.gate(x_se)

@ -0,0 +1,74 @@
""" Squeeze-and-Excitation Channel Attention
An SE implementation originally based on PyTorch SE-Net impl.
Has since evolved with additional functionality / configuration.
Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507
Also included is Effective Squeeze-Excitation (ESE).
Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
Hacked together by / Copyright 2021 Ross Wightman
"""
from torch import nn as nn
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
    """ Squeeze-and-Excitation module as in the original SE-Nets, with a few extras:
        * rd_divisor keeps the reduced channel count % divisor == 0 (default: 8)
        * the reduced channel count can be given directly (rd_channels), which takes priority
        * otherwise it is derived from the float rd_ratio (default: 1/16)
        * global max pooling can be blended into the squeeze (add_maxpool)
        * activation, normalization and gate layers are customizable
    """

    def __init__(
            self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False,
            act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'):
        super(SEModule, self).__init__()
        self.add_maxpool = add_maxpool
        # a falsy rd_channels (None or 0) means "derive the width from rd_ratio"
        reduced = rd_channels or make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.)
        self.fc1 = nn.Conv2d(channels, reduced, kernel_size=1, bias=True)
        self.bn = nn.Identity() if not norm_layer else norm_layer(reduced)
        self.act = create_act_layer(act_layer, inplace=True)
        self.fc2 = nn.Conv2d(reduced, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        squeezed = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            squeezed = 0.5 * squeezed + 0.5 * x.amax((2, 3), keepdim=True)
        excite = self.fc2(self.act(self.bn(self.fc1(squeezed))))
        return x * self.gate(excite)
SqueezeExcite = SEModule # alias
class EffectiveSEModule(nn.Module):
    """ Effective Squeeze-Excitation (ESE)

    Channel attention with a single 1x1 conv (no channel reduction) followed by a gate.
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
    """

    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
        super().__init__()
        self.add_maxpool = add_maxpool
        # One full-width 1x1 conv stands in for the usual reduce/expand pair.
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Squeeze: average over dims 2 and 3, keeping them as size-1 dims.
        pooled = x.mean((2, 3), keepdim=True)
        if self.add_maxpool:
            # experimental codepath, may remove or change
            pooled = 0.5 * pooled + 0.5 * x.amax((2, 3), keepdim=True)
        scale = self.gate(self.fc(pooled))
        return x * scale


EffectiveSqueezeExcite = EffectiveSEModule  # alias

@ -182,7 +182,7 @@ def _nfres_cfg(
def _nfreg_cfg(depths, channels=(48, 104, 208, 440)):
num_features = 1280 * channels[-1] // 440
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25,
num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs)
@ -193,7 +193,7 @@ def _nfnet_cfg(
depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2.,
act_layer='gelu', attn_layer='se', attn_kwargs=None):
num_features = int(channels[-1] * feat_mult)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8)
attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size,
bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer,
@ -202,11 +202,10 @@ def _nfnet_cfg(
def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True):
attn_kwargs = dict(reduction_ratio=0.5, divisor=8)
cfg = NfCfg(
depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128,
bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit,
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs)
num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5))
return cfg
@ -243,7 +242,7 @@ model_cfgs = dict(
# Experimental 'light' versions of NFNet-F that are little leaner
nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'),
attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'),
eca_nfnet_l0=_nfnet_cfg(
depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25,
attn_layer='eca', attn_kwargs=dict(), act_layer='silu'),
@ -272,9 +271,9 @@ model_cfgs = dict(
nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)),
nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)),
nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)),
nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()),

@ -146,7 +146,7 @@ class Bottleneck(nn.Module):
groups=groups, **cargs)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, reduction_channels=se_channels)
self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
else:
self.se = None
cargs['act_layer'] = None

@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)
@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs):
Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579
Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs
"""
attn_layer = partial(get_attn('se'), reduction_ratio=0.25)
attn_layer = partial(get_attn('se'), rd_ratio=0.25)
model_args = dict(
block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True,
avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs)

@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman
"""
import torch.nn as nn
from functools import partial
from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -48,26 +49,7 @@ default_cfgs = dict(
url=''),
)
class SEWithNorm(nn.Module):
    """ Squeeze-and-Excitation block with BatchNorm2d applied after the reduction conv.

    The reduced width is `reduction_channels` when given (truthy), otherwise it is
    computed as `make_divisible(int(channels * se_ratio), divisor=divisor)`.
    """

    def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                 gate_layer='sigmoid'):
        super().__init__()
        if not reduction_channels:
            # Fall back to a width derived from the ratio.
            reduction_channels = make_divisible(int(channels * se_ratio), divisor=divisor)
        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
        self.bn = nn.BatchNorm2d(reduction_channels)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        # Squeeze: average over dims 2 and 3, keeping them as size-1 dims.
        pooled = x.mean((2, 3), keepdim=True)
        # Excite: reduce -> norm -> act -> expand, then gate the input per channel.
        reduced = self.act(self.bn(self.fc1(pooled)))
        scale = self.gate(self.fc2(reduced))
        return x * scale
SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d)
class LinearBottleneck(nn.Module):
@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module):
self.conv_exp = None
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None
if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else:
self.se = None
self.act_dw = create_act_layer(dw_act_layer)
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)

@ -84,8 +84,8 @@ class BasicBlock(nn.Module):
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
reduction_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None
rd_chs = max(planes * self.expansion // 4, 64)
self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None
def forward(self, x):
if self.downsample is not None:
@ -125,7 +125,7 @@ class Bottleneck(nn.Module):
aa_layer(channels=planes, filt_size=3, stride=2))
reduction_chs = max(planes * self.expansion // 8, 64)
self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None
self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None
self.conv3 = conv2d_iabn(
planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity")

@ -13,7 +13,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, overlay_external_default_cfg
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed
from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d
from .registry import register_model
@ -39,15 +39,6 @@ default_cfgs = dict(
)
class LayerNormBHWC(nn.LayerNorm):
    """ LayerNorm applied over dim 1 of a 4D tensor.

    The input is permuted so that dim 1 becomes the last dim, normalized over it with this
    layer's own weight / bias / eps, then permuted back to the original dim order.
    """

    def __init__(self, dim):
        super().__init__(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (0, 1, 2, 3) -> (0, 2, 3, 1): move dim 1 last so layer_norm normalizes over it.
        moved = x.permute(0, 2, 3, 1)
        normed = F.layer_norm(moved, self.normalized_shape, self.weight, self.bias, self.eps)
        # Restore the original dim order.
        return normed.permute(0, 3, 1, 2)
class SpatialMlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None,
act_layer=nn.GELU, drop=0., group=8, spatial_conv=False):
@ -119,7 +110,7 @@ class Attention(nn.Module):
class Block(nn.Module):
def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4.,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC,
drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d,
group=8, attn_disabled=False, spatial_conv=False):
super().__init__()
self.spatial_conv = spatial_conv
@ -148,7 +139,7 @@ class Block(nn.Module):
class Visformer(nn.Module):
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384,
depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111',
norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111',
vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None):
super().__init__()
self.num_classes = num_classes

Loading…
Cancel
Save