diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 788b7518..06217e18 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -17,7 +17,6 @@ from .inception_resnet_v2 import * from .inception_v3 import * from .inception_v4 import * from .levit import * -#from .levit import * from .mlp_mixer import * from .mobilenetv3 import * from .nasnet import * diff --git a/timm/models/byoanet.py b/timm/models/byoanet.py index c179a01c..73c6811b 100644 --- a/timm/models/byoanet.py +++ b/timm/models/byoanet.py @@ -12,24 +12,12 @@ Consider all of the models definitions here as experimental WIP and likely to ch Hacked together by / copyright Ross Wightman, 2021. """ -import math -from dataclasses import dataclass, field -from collections import OrderedDict -from typing import Tuple, List, Optional, Union, Any, Callable -from functools import partial - -import torch -import torch.nn as nn - from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .byobnet import BlocksCfg, ByobCfg, create_byob_stem, create_byob_stages, create_downsample,\ - reduce_feat_size, register_block, num_groups, LayerFn, _init_weights +from .byobnet import ByoBlockCfg, ByoModelCfg, ByobNet, interleave_blocks from .helpers import build_model_with_cfg -from .layers import ClassifierHead, ConvBnAct, DropPath, get_act_layer, convert_norm_act, get_attn, get_self_attn,\ - make_divisible, to_2tuple from .registry import register_model -__all__ = ['ByoaNet'] +__all__ = [] def _cfg(url='', **kwargs): @@ -63,100 +51,68 @@ default_cfgs = { 'swinnet50ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), 'eca_swinnext26ts_256': _cfg(url='', fixed_input_size=True, input_size=(3, 256, 256), pool_size=(8, 8)), - 'rednet26t': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), - 'rednet50ts': _cfg(url='', fixed_input_size=False, input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet26t': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), + 'rednet50ts': _cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), } -@dataclass -class ByoaBlocksCfg(BlocksCfg): - # FIXME allow overriding self_attn layer or args per block/stage, - pass - - -@dataclass -class ByoaCfg(ByobCfg): - blocks: Tuple[Union[ByoaBlocksCfg, Tuple[ByoaBlocksCfg, ...]], ...] = None - self_attn_layer: Optional[str] = None - self_attn_fixed_size: bool = False - self_attn_kwargs: dict = field(default_factory=lambda: dict()) - - -def interleave_attn( - types : Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs -) -> Tuple[ByoaBlocksCfg]: - """ interleave attn blocks - """ - assert len(types) == 2 - if isinstance(every, int): - every = list(range(0 if first else every, d, every)) - if not every: - every = [d - 1] - set(every) - blocks = [] - for i in range(d): - block_type = types[1] if i in every else types[0] - blocks += [ByoaBlocksCfg(type=block_type, d=1, **kwargs)] - return tuple(blocks) - - model_cfgs = dict( - botnet26t=ByoaCfg( + botnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - botnet50ts=ByoaCfg( + botnet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=1, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='', num_features=0, + fixed_input_size=True, act_layer='silu', self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - eca_botnext26ts=ByoaCfg( + eca_botnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', attn_layer='eca', self_attn_layer='bottleneck', - self_attn_fixed_size=True, self_attn_kwargs=dict() ), - halonet_h1=ByoaCfg( + halonet_h1=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), ), stem_chs=64, stem_type='7x7', @@ -165,12 +121,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3), ), - halonet_h1_c4c5=ByoaCfg( + halonet_h1_c4c5=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), - ByoaBlocksCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), - ByoaBlocksCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), + ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=0, br=1.0), + ByoBlockCfg(type='bottle', d=3, c=128, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=10, c=256, s=2, gs=0, br=1.0), + ByoBlockCfg(type='self_attn', d=3, c=512, s=2, gs=0, br=1.0), ), stem_chs=64, stem_type='tiered', @@ -179,12 +135,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=3), ), - halonet26t=ByoaCfg( + halonet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -193,12 +149,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res ), - halonet50ts=ByoaCfg( + halonet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -208,12 +164,12 @@ model_cfgs = dict( self_attn_layer='halo', self_attn_kwargs=dict(block_size=8, halo_size=2) ), - eca_halonext26ts=ByoaCfg( + eca_halonext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -225,12 +181,12 @@ model_cfgs = dict( self_attn_kwargs=dict(block_size=8, halo_size=2) # intended for 256x256 res ), - lambda_resnet26t=ByoaCfg( + lambda_resnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -239,12 +195,12 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), - lambda_resnet50t=ByoaCfg( + lambda_resnet50t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=3, d=6, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -253,12 +209,12 @@ model_cfgs = dict( self_attn_layer='lambda', self_attn_kwargs=dict() ), - eca_lambda_resnext26ts=ByoaCfg( + eca_lambda_resnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - ByoaBlocksCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -270,77 +226,76 @@ model_cfgs = dict( self_attn_kwargs=dict() ), - swinnet26t=ByoaCfg( + swinnet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - swinnet50ts=ByoaCfg( + swinnet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=256, s=1, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=4, c=512, s=2, gs=0, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - eca_swinnext26ts=ByoaCfg( + eca_swinnext26ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), - interleave_attn(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=512, s=2, gs=16, br=0.25), + interleave_blocks(types=('bottle', 'self_attn'), every=1, d=2, c=1024, s=2, gs=16, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=16, br=0.25), ), stem_chs=64, stem_type='tiered', stem_pool='maxpool', num_features=0, + fixed_input_size=True, act_layer='silu', attn_layer='eca', self_attn_layer='swin', - self_attn_fixed_size=True, self_attn_kwargs=dict(win_size=8) ), - rednet26t=ByoaCfg( + rednet26t=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', # FIXME RedNet uses involution in middle of stem stem_pool='maxpool', num_features=0, self_attn_layer='involution', - self_attn_fixed_size=False, self_attn_kwargs=dict() ), - rednet50ts=ByoaCfg( + rednet50ts=ByoModelCfg( blocks=( - ByoaBlocksCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), - ByoaBlocksCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=256, s=1, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=4, c=512, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=2, c=1024, s=2, gs=0, br=0.25), + ByoBlockCfg(type='self_attn', d=3, c=2048, s=2, gs=0, br=0.25), ), stem_chs=64, stem_type='tiered', @@ -348,161 +303,14 @@ model_cfgs = dict( num_features=0, act_layer='silu', self_attn_layer='involution', - self_attn_fixed_size=False, self_attn_kwargs=dict() ), ) -@dataclass -class ByoaLayerFn(LayerFn): - self_attn: Optional[Callable] = None - - -class SelfAttnBlock(nn.Module): - """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 - """ - - def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, - layers: ByoaLayerFn = None, drop_block=None, drop_path_rate=0.): - super(SelfAttnBlock, self).__init__() - assert layers is not None - mid_chs = make_divisible(out_chs * bottle_ratio) - groups = num_groups(group_size, mid_chs) - - if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: - self.shortcut = create_downsample( - downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], - apply_act=False, layers=layers) - else: - self.shortcut = nn.Identity() - - self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) - if extra_conv: - self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) - stride = 1 # striding done via conv if enabled - else: - self.conv2_kxk = nn.Identity() - opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) - # FIXME need to dilate self attn to have dilated network support, moop moop - self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) - self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() - self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() - self.act = nn.Identity() if linear_out else layers.act(inplace=True) - - def init_weights(self, zero_init_last_bn=False): - if zero_init_last_bn: - nn.init.zeros_(self.conv3_1x1.bn.weight) - if hasattr(self.self_attn, 'reset_parameters'): - self.self_attn.reset_parameters() - - def forward(self, x): - shortcut = self.shortcut(x) - - x = self.conv1_1x1(x) - x = self.conv2_kxk(x) - x = self.self_attn(x) - x = self.post_attn(x) - x = self.conv3_1x1(x) - x = self.drop_path(x) - - x = self.act(x + shortcut) - return x - -register_block('self_attn', SelfAttnBlock) - - -def _byoa_block_args(block_kwargs, block_cfg: ByoaBlocksCfg, model_cfg: ByoaCfg, feat_size=None): - if block_cfg.type == 'self_attn' and model_cfg.self_attn_fixed_size: - assert feat_size is not None - block_kwargs['feat_size'] = feat_size - return block_kwargs - - -def get_layer_fns(cfg: ByoaCfg): - act = get_act_layer(cfg.act_layer) - norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) - conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) - attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None - layer_fn = ByoaLayerFn( - conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) - return layer_fn - - -class ByoaNet(nn.Module): - """ 'Bring-your-own-attention' Net - - A ResNet inspired backbone that supports interleaving traditional residual blocks with - 'Self Attention' bottleneck blocks that replace the bottleneck kxk conv w/ a self-attention - or similar module. - - FIXME This class network definition is almost the same as ByobNet, I'd like to merge them but - torchscript limitations prevent sensible inheritance overrides. - """ - def __init__(self, cfg: ByoaCfg, num_classes=1000, in_chans=3, output_stride=32, global_pool='avg', - zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): - super().__init__() - self.num_classes = num_classes - self.drop_rate = drop_rate - layers = get_layer_fns(cfg) - feat_size = to_2tuple(img_size) if img_size is not None else None - - self.feature_info = [] - stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) - self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) - self.feature_info.extend(stem_feat[:-1]) - feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) - - self.stages, stage_feat = create_byob_stages( - cfg, drop_path_rate, output_stride, stem_feat[-1], - feat_size=feat_size, layers=layers, extra_args_fn=_byoa_block_args) - self.feature_info.extend(stage_feat[:-1]) - - prev_chs = stage_feat[-1]['num_chs'] - if cfg.num_features: - self.num_features = int(round(cfg.width_factor * cfg.num_features)) - self.final_conv = layers.conv_norm_act(prev_chs, self.num_features, 1) - else: - self.num_features = prev_chs - self.final_conv = nn.Identity() - self.feature_info += [ - dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')] - - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - for n, m in self.named_modules(): - _init_weights(m, n) - for m in self.modules(): - # call each block's weight init for block-specific overrides to init above - if hasattr(m, 'init_weights'): - m.init_weights(zero_init_last_bn=zero_init_last_bn) - - def get_classifier(self): - return self.head.fc - - def reset_classifier(self, num_classes, global_pool='avg'): - self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - - def forward_features(self, x): - x = self.stem(x) - x = self.stages(x) - x = self.final_conv(x) - return x - - def forward(self, x): - x = self.forward_features(x) - x = self.head(x) - return x - - def _create_byoanet(variant, cfg_variant=None, pretrained=False, **kwargs): return build_model_with_cfg( - ByoaNet, variant, pretrained, + ByobNet, variant, pretrained, default_cfg=default_cfgs[variant], model_cfg=model_cfgs[variant] if not cfg_variant else model_cfgs[cfg_variant], feature_cfg=dict(flatten_sequential=True), diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 8f4a2020..3f162c79 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -26,8 +26,7 @@ Hacked together by / copyright Ross Wightman, 2021. """ import math from dataclasses import dataclass, field, replace -from collections import OrderedDict -from typing import Tuple, List, Optional, Union, Any, Callable, Sequence +from typing import Tuple, List, Dict, Optional, Union, Any, Callable, Sequence from functools import partial import torch @@ -36,10 +35,10 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible + create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple from .registry import register_model -__all__ = ['ByobNet', 'ByobCfg', 'BlocksCfg', 'create_byob_stem', 'create_block'] +__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] def _cfg(url='', **kwargs): @@ -87,35 +86,52 @@ default_cfgs = { 'repvgg_b3g4': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-repvgg-weights/repvgg_b3g4-73c370bf.pth', first_conv=('stem.conv_kxk.conv', 'stem.conv_1x1.conv')), + + # experimental configs + 'resnet52qs': _cfg(first_conv='stem.conv1.conv'), + 'geresnet50t': _cfg(first_conv='stem.conv1.conv'), + 'gcresnet50t': _cfg(first_conv='stem.conv1.conv'), } @dataclass -class BlocksCfg: +class ByoBlockCfg: type: Union[str, nn.Module] d: int # block depth (number of block repeats in stage) c: int # number of output channels for each block in stage s: int = 2 # stride of stage (first block) gs: Optional[Union[int, Callable]] = None # group-size of blocks in stage, conv is depthwise if gs == 1 br: float = 1. # bottleneck-ratio of blocks in stage - no_attn: bool = False # disable channel attn (ie SE) when layer is set for model + + # NOTE: these config items override the model cfgs that are applied to all blocks by default + attn_layer: Optional[str] = None + attn_kwargs: Optional[Dict[str, Any]] = None + self_attn_layer: Optional[str] = None + self_attn_kwargs: Optional[Dict[str, Any]] = None + block_kwargs: Optional[Dict[str, Any]] = None @dataclass -class ByobCfg: - blocks: Tuple[Union[BlocksCfg, Tuple[BlocksCfg, ...]], ...] +class ByoModelCfg: + blocks: Tuple[Union[ByoBlockCfg, Tuple[ByoBlockCfg, ...]], ...] downsample: str = 'conv1x1' stem_type: str = '3x3' - stem_pool: str = '' + stem_pool: Optional[str] = 'maxpool' stem_chs: int = 32 width_factor: float = 1.0 num_features: int = 0 # num out_channels for final conv, no final 1x1 conv if 0 zero_init_last_bn: bool = True + fixed_input_size: bool = False # model constrained to a fixed-input size / img_size must be provided on creation act_layer: str = 'relu' norm_layer: str = 'batchnorm' + + # NOTE: these config items will be overridden by the block cfg (per-block) if they are set there attn_layer: Optional[str] = None attn_kwargs: dict = field(default_factory=lambda: dict()) + self_attn_layer: Optional[str] = None + self_attn_kwargs: dict = field(default_factory=lambda: dict()) + block_kwargs: Dict[str, Any] = field(default_factory=lambda: dict()) def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): @@ -123,103 +139,155 @@ def _rep_vgg_bcfg(d=(4, 6, 16, 1), wf=(1., 1., 1., 1.), groups=0): group_size = 0 if groups > 0: group_size = lambda chs, idx: chs // groups if (idx + 1) % 2 == 0 else 0 - bcfg = tuple([BlocksCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)]) + bcfg = tuple([ByoBlockCfg(type='rep', d=d, c=c * wf, gs=group_size) for d, c, wf in zip(d, c, wf)]) return bcfg -model_cfgs = dict( +def interleave_blocks( + types: Tuple[str, str], every: Union[int, List[int]], d, first: bool = False, **kwargs +) -> Tuple[ByoBlockCfg]: + """ interleave 2 block types in stack + """ + assert len(types) == 2 + if isinstance(every, int): + every = list(range(0 if first else every, d, every)) + if not every: + every = [d - 1] + set(every) + blocks = [] + for i in range(d): + block_type = types[1] if i in every else types[0] + blocks += [ByoBlockCfg(type=block_type, d=1, **kwargs)] + return tuple(blocks) - gernet_l=ByobCfg( + +model_cfgs = dict( + gernet_l=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=5, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=4, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ), - gernet_m=ByobCfg( + gernet_m=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=128, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=2, c=192, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=6, c=640, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=4, c=640, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=640, s=1, gs=1, br=3.), ), stem_chs=32, + stem_pool=None, num_features=2560, ), - gernet_s=ByobCfg( + gernet_s=ByoModelCfg( blocks=( - BlocksCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), - BlocksCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), - BlocksCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), - BlocksCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), + ByoBlockCfg(type='basic', d=1, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='basic', d=3, c=48, s=2, gs=0, br=1.), + ByoBlockCfg(type='bottle', d=7, c=384, s=2, gs=0, br=1 / 4), + ByoBlockCfg(type='bottle', d=2, c=560, s=2, gs=1, br=3.), + ByoBlockCfg(type='bottle', d=1, c=256, s=1, gs=1, br=3.), ), stem_chs=13, + stem_pool=None, num_features=1920, ), - repvgg_a2=ByobCfg( + repvgg_a2=ByoModelCfg( blocks=_rep_vgg_bcfg(d=(2, 4, 14, 1), wf=(1.5, 1.5, 1.5, 2.75)), stem_type='rep', stem_chs=64, ), - repvgg_b0=ByobCfg( + repvgg_b0=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(1., 1., 1., 2.5)), stem_type='rep', stem_chs=64, ), - repvgg_b1=ByobCfg( + repvgg_b1=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.)), stem_type='rep', stem_chs=64, ), - repvgg_b1g4=ByobCfg( + repvgg_b1g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2., 2., 2., 4.), groups=4), stem_type='rep', stem_chs=64, ), - repvgg_b2=ByobCfg( + repvgg_b2=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.)), stem_type='rep', stem_chs=64, ), - repvgg_b2g4=ByobCfg( + repvgg_b2g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(2.5, 2.5, 2.5, 5.), groups=4), stem_type='rep', stem_chs=64, ), - repvgg_b3=ByobCfg( + repvgg_b3=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.)), stem_type='rep', stem_chs=64, ), - repvgg_b3g4=ByobCfg( + repvgg_b3g4=ByoModelCfg( blocks=_rep_vgg_bcfg(wf=(3., 3., 3., 5.), groups=4), stem_type='rep', stem_chs=64, ), - resnet52q=ByobCfg( + # WARN: experimental, may vanish/change + resnet52q=ByoModelCfg( blocks=( - BlocksCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), - BlocksCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), - BlocksCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), - BlocksCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), + ByoBlockCfg(type='bottle', d=2, c=256, s=1, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1536, s=2, gs=32, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=1536, s=2, gs=1, br=1.0), ), stem_chs=128, stem_type='quad', num_features=2048, act_layer='silu', ), + + # WARN: experimental, may vanish/change + geresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='edge', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='edge', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='ge', + attn_kwargs=dict(extent=8, extra_params=True), + #attn_kwargs=dict(extent=8), + #block_kwargs=dict(attn_last=True) + ), + + # WARN: experimental, may vanish/change + gcresnet50t=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=3, c=256, s=1, br=0.25), + ByoBlockCfg(type='bottle', d=4, c=512, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=6, c=1024, s=2, br=0.25), + ByoBlockCfg(type='bottle', d=3, c=2048, s=2, br=0.25), + ), + stem_chs=64, + stem_type='tiered', + stem_pool=None, + attn_layer='gc' + ), ) -def expand_blocks_cfg(stage_blocks_cfg: Union[BlocksCfg, Sequence[BlocksCfg]]) -> List[BlocksCfg]: +def expand_blocks_cfg(stage_blocks_cfg: Union[ByoBlockCfg, Sequence[ByoBlockCfg]]) -> List[ByoBlockCfg]: if not isinstance(stage_blocks_cfg, Sequence): stage_blocks_cfg = (stage_blocks_cfg,) block_cfgs = [] @@ -243,6 +311,7 @@ class LayerFn: norm_act: Callable = BatchNormAct2d act: Callable = nn.ReLU attn: Optional[Callable] = None + self_attn: Optional[Callable] = None class DownsampleAvg(nn.Module): @@ -275,7 +344,8 @@ class BasicBlock(nn.Module): def __init__( self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), group_size=None, bottle_ratio=1.0, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(BasicBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -289,15 +359,19 @@ class BasicBlock(nn.Module): self.shortcut = nn.Identity() self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -317,7 +391,8 @@ class BottleneckBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, - downsample='avg', linear_out=False, layers : LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(BottleneckBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -334,14 +409,18 @@ class BottleneckBlock(nn.Module): self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) - self.attn = nn.Identity() if layers.attn is None else layers.attn(mid_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv3_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -350,6 +429,7 @@ class BottleneckBlock(nn.Module): x = self.conv2_kxk(x) x = self.attn(x) x = self.conv3_1x1(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) @@ -368,7 +448,8 @@ class DarkBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=True, linear_out=False, layers: LayerFn = None, drop_block=None, + drop_path_rate=0.): super(DarkBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -382,23 +463,28 @@ class DarkBlock(nn.Module): self.shortcut = nn.Identity() self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block, apply_act=False) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_kxk.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) x = self.conv1_1x1(x) - x = self.conv2_kxk(x) x = self.attn(x) + x = self.conv2_kxk(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) return x @@ -415,7 +501,8 @@ class EdgeBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='avg', linear_out=False, layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='avg', attn_last=False, linear_out=False, layers: LayerFn = None, + drop_block=None, drop_path_rate=0.): super(EdgeBlock, self).__init__() layers = layers or LayerFn() mid_chs = make_divisible(out_chs * bottle_ratio) @@ -431,14 +518,18 @@ class EdgeBlock(nn.Module): self.conv1_kxk = layers.conv_norm_act( in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_block=drop_block) - self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) + self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): if zero_init_last_bn: nn.init.zeros_(self.conv2_1x1.bn.weight) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): shortcut = self.shortcut(x) @@ -446,6 +537,7 @@ class EdgeBlock(nn.Module): x = self.conv1_kxk(x) x = self.attn(x) x = self.conv2_1x1(x) + x = self.attn_last(x) x = self.drop_path(x) x = self.act(x + shortcut) return x @@ -460,7 +552,7 @@ class RepVggBlock(nn.Module): """ def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1.0, group_size=None, - downsample='', layers : LayerFn = None, drop_block=None, drop_path_rate=0.): + downsample='', layers: LayerFn = None, drop_block=None, drop_path_rate=0.): super(RepVggBlock, self).__init__() layers = layers or LayerFn() groups = num_groups(group_size, in_chs) @@ -475,12 +567,15 @@ class RepVggBlock(nn.Module): self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity() self.act = layers.act(inplace=True) - def init_weights(self, zero_init_last_bn=False): + def init_weights(self, zero_init_last_bn: bool = False): # NOTE this init overrides that base model init with specific changes for the block type for m in self.modules(): if isinstance(m, nn.BatchNorm2d): nn.init.normal_(m.weight, .1, .1) nn.init.normal_(m.bias, 0, .1) + for attn in (self.attn, self.attn_last): + if hasattr(attn, 'reset_parameters'): + attn.reset_parameters() def forward(self, x): if self.identity is None: @@ -495,12 +590,68 @@ class RepVggBlock(nn.Module): return x +class SelfAttnBlock(nn.Module): + """ ResNet-like Bottleneck Block - 1x1 - optional kxk - self attn - 1x1 + """ + + def __init__(self, in_chs, out_chs, kernel_size=3, stride=1, dilation=(1, 1), bottle_ratio=1., group_size=None, + downsample='avg', extra_conv=False, linear_out=False, post_attn_na=True, feat_size=None, + layers: LayerFn = None, drop_block=None, drop_path_rate=0.): + super(SelfAttnBlock, self).__init__() + assert layers is not None + mid_chs = make_divisible(out_chs * bottle_ratio) + groups = num_groups(group_size, mid_chs) + + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = create_downsample( + downsample, in_chs=in_chs, out_chs=out_chs, stride=stride, dilation=dilation[0], + apply_act=False, layers=layers) + else: + self.shortcut = nn.Identity() + + self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) + if extra_conv: + self.conv2_kxk = layers.conv_norm_act( + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], + groups=groups, drop_block=drop_block) + stride = 1 # striding done via conv if enabled + else: + self.conv2_kxk = nn.Identity() + opt_kwargs = {} if feat_size is None else dict(feat_size=feat_size) + # FIXME need to dilate self attn to have dilated network support, moop moop + self.self_attn = layers.self_attn(mid_chs, stride=stride, **opt_kwargs) + self.post_attn = layers.norm_act(mid_chs) if post_attn_na else nn.Identity() + self.conv3_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() + self.act = nn.Identity() if linear_out else layers.act(inplace=True) + + def init_weights(self, zero_init_last_bn: bool = False): + if zero_init_last_bn: + nn.init.zeros_(self.conv3_1x1.bn.weight) + if hasattr(self.self_attn, 'reset_parameters'): + self.self_attn.reset_parameters() + + def forward(self, x): + shortcut = self.shortcut(x) + + x = self.conv1_1x1(x) + x = self.conv2_kxk(x) + x = self.self_attn(x) + x = self.post_attn(x) + x = self.conv3_1x1(x) + x = self.drop_path(x) + + x = self.act(x + shortcut) + return x + + _block_registry = dict( basic=BasicBlock, bottle=BottleneckBlock, dark=DarkBlock, edge=EdgeBlock, rep=RepVggBlock, + self_attn=SelfAttnBlock, ) @@ -552,7 +703,7 @@ class Stem(nn.Sequential): curr_stride *= s prev_feat = conv_name - if 'max' in pool.lower(): + if pool and 'max' in pool.lower(): self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat)) self.add_module('pool', nn.MaxPool2d(3, 2, 1)) curr_stride *= 2 @@ -601,9 +752,58 @@ def reduce_feat_size(feat_size, stride=2): return None if feat_size is None else tuple([s // stride for s in feat_size]) +def override_kwargs(block_kwargs, model_kwargs): + """ Override model level attn/self-attn/block kwargs w/ block level + + NOTE: kwargs are NOT merged across levels, block_kwargs will fully replace model_kwargs + for the block if set to anything that isn't None. + + i.e. an empty block_kwargs dict will remove kwargs set at model level for that block + """ + out_kwargs = block_kwargs if block_kwargs is not None else model_kwargs + return out_kwargs or {} # make sure None isn't returned + + +def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, model_cfg: ByoModelCfg, ): + layer_fns = block_kwargs['layers'] + + # override attn layer / args with block local config + if block_cfg.attn_kwargs is not None or block_cfg.attn_layer is not None: + # override attn layer config + if not block_cfg.attn_layer: + # empty string for attn_layer type will disable attn for this block + attn_layer = None + else: + attn_kwargs = override_kwargs(block_cfg.attn_kwargs, model_cfg.attn_kwargs) + attn_layer = block_cfg.attn_layer or model_cfg.attn_layer + attn_layer = partial(get_attn(attn_layer), *attn_kwargs) if attn_layer is not None else None + layer_fns = replace(layer_fns, attn=attn_layer) + + # override self-attn layer / args with block local cfg + if block_cfg.self_attn_kwargs is not None or block_cfg.self_attn_layer is not None: + # override attn layer config + if not block_cfg.self_attn_layer: + # empty string for self_attn_layer type will disable attn for this block + self_attn_layer = None + else: + self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) + self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer + self_attn_layer = partial(get_self_attn(self_attn_layer), *self_attn_kwargs) \ + if self_attn_layer is not None else None + layer_fns = replace(layer_fns, self_attn=self_attn_layer) + + block_kwargs['layers'] = layer_fns + + # add additional block_kwargs specified in block_cfg or model_cfg, precedence to block if set + block_kwargs.update(override_kwargs(block_cfg.block_kwargs, model_cfg.block_kwargs)) + + def create_byob_stages( - cfg, drop_path_rate, output_stride, stem_feat, - feat_size=None, layers=None, extra_args_fn=None): + cfg: ByoModelCfg, drop_path_rate: float, output_stride: int, stem_feat: Dict[str, Any], + feat_size: Optional[int] = None, + layers: Optional[LayerFn] = None, + block_kwargs_fn: Optional[Callable] = update_block_kwargs): + layers = layers or LayerFn() feature_info = [] block_cfgs = [expand_blocks_cfg(s) for s in cfg.blocks] @@ -641,8 +841,10 @@ def create_byob_stages( drop_path_rate=dpr[stage_idx][block_idx], layers=layers, ) - if extra_args_fn is not None: - extra_args_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg, feat_size=feat_size) + if block_cfg.type in ('self_attn',): + # add feat_size arg for blocks that support/need it + block_kwargs['feat_size'] = feat_size + block_kwargs_fn(block_kwargs, block_cfg=block_cfg, model_cfg=cfg) blocks += [create_block(block_cfg.type, **block_kwargs)] first_dilation = dilation prev_chs = out_chs @@ -656,12 +858,13 @@ def create_byob_stages( return nn.Sequential(*stages), feature_info -def get_layer_fns(cfg: ByobCfg): +def get_layer_fns(cfg: ByoModelCfg): act = get_act_layer(cfg.act_layer) norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn) + self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) return layer_fn @@ -673,19 +876,24 @@ class ByobNet(nn.Module): Current assumption is that both stem and blocks are in conv-bn-act order (w/ block ending in act). """ - def __init__(self, cfg: ByobCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, - zero_init_last_bn=True, drop_rate=0., drop_path_rate=0.): + def __init__(self, cfg: ByoModelCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32, + zero_init_last_bn=True, img_size=None, drop_rate=0., drop_path_rate=0.): super().__init__() self.num_classes = num_classes self.drop_rate = drop_rate layers = get_layer_fns(cfg) + if cfg.fixed_input_size: + assert img_size is not None, 'img_size argument is required for fixed input size model' + feat_size = to_2tuple(img_size) if img_size is not None else None self.feature_info = [] stem_chs = int(round((cfg.stem_chs or cfg.blocks[0].c) * cfg.width_factor)) self.stem, stem_feat = create_byob_stem(in_chans, stem_chs, cfg.stem_type, cfg.stem_pool, layers=layers) self.feature_info.extend(stem_feat[:-1]) + feat_size = reduce_feat_size(feat_size, stride=stem_feat[-1]['reduction']) - self.stages, stage_feat = create_byob_stages(cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers) + self.stages, stage_feat = create_byob_stages( + cfg, drop_path_rate, output_stride, stem_feat[-1], layers=layers, feat_size=feat_size) self.feature_info.extend(stage_feat[:-1]) prev_chs = stage_feat[-1]['num_chs'] @@ -836,3 +1044,24 @@ def repvgg_b3g4(pretrained=False, **kwargs): `Making VGG-style ConvNets Great Again` - https://arxiv.org/abs/2101.03697 """ return _create_byobnet('repvgg_b3g4', pretrained=pretrained, **kwargs) + + +@register_model +def resnet52q(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def geresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('geresnet50t', pretrained=pretrained, **kwargs) + + +@register_model +def gcresnet50t(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('gcresnet50t', pretrained=pretrained, **kwargs) diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index cd192281..30a1b40d 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -14,20 +14,22 @@ from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act from .create_self_attn import get_self_attn, create_self_attn from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path -from .eca import EcaModule, CecaModule +from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .gather_excite import GatherExcite +from .global_context import GlobalContext from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible from .inplace_abn import InplaceAbn from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp -from .norm import GroupNorm +from .norm import GroupNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d -from .se import SEModule +from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernelConv from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py index 44e2fe6d..bacf5cf0 100644 --- a/timm/models/layers/cbam.py +++ b/timm/models/layers/cbam.py @@ -7,78 +7,87 @@ some tasks, especially fine-grained it seems. I may end up removing this impl. Hacked together by / Copyright 2020 Ross Wightman """ - import torch from torch import nn as nn import torch.nn.functional as F + from .conv_bn_act import ConvBnAct +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible class ChannelAttn(nn.Module): """ Original CBAM channel attention module, currently avg + max pool variant only. """ - def __init__(self, channels, reduction=16, act_layer=nn.ReLU): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(ChannelAttn, self).__init__() - self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False) + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, 1, bias=mlp_bias) self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False) + self.fc2 = nn.Conv2d(rd_channels, channels, 1, bias=mlp_bias) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = x.mean((2, 3), keepdim=True) - x_max = F.adaptive_max_pool2d(x, 1) - x_avg = self.fc2(self.act(self.fc1(x_avg))) - x_max = self.fc2(self.act(self.fc1(x_max))) - x_attn = x_avg + x_max - return x * x_attn.sigmoid() + x_avg = self.fc2(self.act(self.fc1(x.mean((2, 3), keepdim=True)))) + x_max = self.fc2(self.act(self.fc1(x.amax((2, 3), keepdim=True)))) + return x * self.gate(x_avg + x_max) class LightChannelAttn(ChannelAttn): """An experimental 'lightweight' that sums avg + max pool first """ - def __init__(self, channels, reduction=16): - super(LightChannelAttn, self).__init__(channels, reduction) + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): + super(LightChannelAttn, self).__init__( + channels, rd_ratio, rd_channels, rd_divisor, act_layer, gate_layer, mlp_bias) def forward(self, x): - x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1) + x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * x.amax((2, 3), keepdim=True) x_attn = self.fc2(self.act(self.fc1(x_pool))) - return x * x_attn.sigmoid() + return x * F.sigmoid(x_attn) class SpatialAttn(nn.Module): """ Original CBAM spatial attention module """ - def __init__(self, kernel_size=7): + def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(SpatialAttn, self).__init__() self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = torch.mean(x, dim=1, keepdim=True) - x_max = torch.max(x, dim=1, keepdim=True)[0] - x_attn = torch.cat([x_avg, x_max], dim=1) + x_attn = torch.cat([x.mean(dim=1, keepdim=True), x.amax(dim=1, keepdim=True)], dim=1) x_attn = self.conv(x_attn) - return x * x_attn.sigmoid() + return x * self.gate(x_attn) class LightSpatialAttn(nn.Module): """An experimental 'lightweight' variant that sums avg_pool and max_pool results. """ - def __init__(self, kernel_size=7): + def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(LightSpatialAttn, self).__init__() self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.gate = create_act_layer(gate_layer) def forward(self, x): - x_avg = torch.mean(x, dim=1, keepdim=True) - x_max = torch.max(x, dim=1, keepdim=True)[0] - x_attn = 0.5 * x_avg + 0.5 * x_max + x_attn = 0.5 * x.mean(dim=1, keepdim=True) + 0.5 * x.amax(dim=1, keepdim=True) x_attn = self.conv(x_attn) - return x * x_attn.sigmoid() + return x * self.gate(x_attn) class CbamModule(nn.Module): - def __init__(self, channels, spatial_kernel_size=7): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(CbamModule, self).__init__() - self.channel = ChannelAttn(channels) - self.spatial = SpatialAttn(spatial_kernel_size) + self.channel = ChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) + self.spatial = SpatialAttn(spatial_kernel_size, gate_layer=gate_layer) def forward(self, x): x = self.channel(x) @@ -87,9 +96,13 @@ class CbamModule(nn.Module): class LightCbamModule(nn.Module): - def __init__(self, channels, spatial_kernel_size=7): + def __init__( + self, channels, rd_ratio=1./16, rd_channels=None, rd_divisor=1, + spatial_kernel_size=7, act_layer=nn.ReLU, gate_layer='sigmoid', mlp_bias=False): super(LightCbamModule, self).__init__() - self.channel = LightChannelAttn(channels) + self.channel = LightChannelAttn( + channels, rd_ratio=rd_ratio, rd_channels=rd_channels, + rd_divisor=rd_divisor, act_layer=act_layer, gate_layer=gate_layer, mlp_bias=mlp_bias) self.spatial = LightSpatialAttn(spatial_kernel_size) def forward(self, x): diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index ff20e5df..de866eea 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -3,9 +3,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch -from .se import SEModule, EffectiveSEModule -from .eca import EcaModule, CecaModule + from .cbam import CbamModule, LightCbamModule +from .eca import EcaModule, CecaModule +from .gather_excite import GatherExcite +from .global_context import GlobalContext +from .squeeze_excite import SEModule, EffectiveSEModule def get_attn(attn_type): @@ -23,6 +26,10 @@ def get_attn(attn_type): module_cls = EcaModule elif attn_type == 'ceca': module_cls = CecaModule + elif attn_type == 'ge': + module_cls = GatherExcite + elif attn_type == 'gc': + module_cls = GlobalContext elif attn_type == 'cbam': module_cls = CbamModule elif attn_type == 'lcbam': diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 3a7f8b82..d0d8f74a 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -65,6 +65,9 @@ class EcaModule(nn.Module): return x * y.expand_as(x) +EfficientChannelAttn = EcaModule # alias + + class CecaModule(nn.Module): """Constructs a circular ECA module. @@ -105,3 +108,6 @@ class CecaModule(nn.Module): y = self.conv(y) y = y.view(x.shape[0], -1, 1, 1).sigmoid() return x * y.expand_as(x) + + +CircularEfficientChannelAttn = CecaModule \ No newline at end of file diff --git a/timm/models/layers/gather_excite.py b/timm/models/layers/gather_excite.py new file mode 100644 index 00000000..2d60dc96 --- /dev/null +++ b/timm/models/layers/gather_excite.py @@ -0,0 +1,90 @@ +""" Gather-Excite Attention Block + +Paper: `Gather-Excite: Exploiting Feature Context in CNNs` - https://arxiv.org/abs/1810.12348 + +Official code here, but it's only partial impl in Caffe: https://github.com/hujie-frank/GENet + +I've tried to support all of the extent both w/ and w/o params. I don't believe I've seen another +impl that covers all of the cases. + +NOTE: extent=0 + extra_params=False is equivalent to Squeeze-and-Excitation + +Hacked together by / Copyright 2021 Ross Wightman +""" +import math + +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .create_conv2d import create_conv2d +from .helpers import make_divisible +from .mlp import ConvMlp + + +class GatherExcite(nn.Module): + """ Gather-Excite Attention Module + """ + def __init__( + self, channels, feat_size=None, extra_params=False, extent=0, use_mlp=True, + rd_ratio=1./16, rd_channels=None, rd_divisor=1, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, gate_layer='sigmoid'): + super(GatherExcite, self).__init__() + self.add_maxpool = add_maxpool + act_layer = get_act_layer(act_layer) + self.extent = extent + if extra_params: + self.gather = nn.Sequential() + if extent == 0: + assert feat_size is not None, 'spatial feature size must be specified for global extent w/ params' + self.gather.add_module( + 'conv1', create_conv2d(channels, channels, kernel_size=feat_size, stride=1, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm1', nn.BatchNorm2d(channels)) + else: + assert extent % 2 == 0 + num_conv = int(math.log2(extent)) + for i in range(num_conv): + self.gather.add_module( + f'conv{i + 1}', + create_conv2d(channels, channels, kernel_size=3, stride=2, depthwise=True)) + if norm_layer: + self.gather.add_module(f'norm{i + 1}', nn.BatchNorm2d(channels)) + if i != num_conv - 1: + self.gather.add_module(f'act{i + 1}', act_layer(inplace=True)) + else: + self.gather = None + if self.extent == 0: + self.gk = 0 + self.gs = 0 + else: + assert extent % 2 == 0 + self.gk = self.extent * 2 - 1 + self.gs = self.extent + + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.mlp = ConvMlp(channels, rd_channels, act_layer=act_layer) if use_mlp else nn.Identity() + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + size = x.shape[-2:] + if self.gather is not None: + x_ge = self.gather(x) + else: + if self.extent == 0: + # global extent + x_ge = x.mean(dim=(2, 3), keepdims=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * x.amax((2, 3), keepdim=True) + else: + x_ge = F.avg_pool2d( + x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2, count_include_pad=False) + if self.add_maxpool: + # experimental codepath, may remove or change + x_ge = 0.5 * x_ge + 0.5 * F.max_pool2d(x, kernel_size=self.gk, stride=self.gs, padding=self.gk // 2) + x_ge = self.mlp(x_ge) + if x_ge.shape[-1] != 1 or x_ge.shape[-2] != 1: + x_ge = F.interpolate(x_ge, size=size) + return x * self.gate(x_ge) diff --git a/timm/models/layers/global_context.py b/timm/models/layers/global_context.py new file mode 100644 index 00000000..4c2c82f3 --- /dev/null +++ b/timm/models/layers/global_context.py @@ -0,0 +1,67 @@ +""" Global Context Attention Block + +Paper: `GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond` + - https://arxiv.org/abs/1904.11492 + +Official code consulted as reference: https://github.com/xvjiarui/GCNet + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn +import torch.nn.functional as F + +from .create_act import create_act_layer, get_act_layer +from .helpers import make_divisible +from .mlp import ConvMlp +from .norm import LayerNorm2d + + +class GlobalContext(nn.Module): + + def __init__(self, channels, use_attn=True, fuse_add=True, fuse_scale=False, init_last_zero=False, + rd_ratio=1./8, rd_channels=None, rd_divisor=1, act_layer=nn.ReLU, gate_layer='sigmoid'): + super(GlobalContext, self).__init__() + act_layer = get_act_layer(act_layer) + + self.conv_attn = nn.Conv2d(channels, 1, kernel_size=1, bias=True) if use_attn else None + + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + if fuse_add: + self.mlp_add = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_add = None + if fuse_scale: + self.mlp_scale = ConvMlp(channels, rd_channels, act_layer=act_layer, norm_layer=LayerNorm2d) + else: + self.mlp_scale = None + + self.gate = create_act_layer(gate_layer) + self.init_last_zero = init_last_zero + self.reset_parameters() + + def reset_parameters(self): + if self.conv_attn is not None: + nn.init.kaiming_normal_(self.conv_attn.weight, mode='fan_in', nonlinearity='relu') + if self.mlp_add is not None: + nn.init.zeros_(self.mlp_add.fc2.weight) + + def forward(self, x): + B, C, H, W = x.shape + + if self.conv_attn is not None: + attn = self.conv_attn(x).reshape(B, 1, H * W) # (B, 1, H * W) + attn = F.softmax(attn, dim=-1).unsqueeze(3) # (B, 1, H * W, 1) + context = x.reshape(B, C, H * W).unsqueeze(1) @ attn + context = context.view(B, C, 1, 1) + else: + context = x.mean(dim=(2, 3), keepdim=True) + + if self.mlp_scale is not None: + mlp_x = self.mlp_scale(context) + x = x * self.gate(mlp_x) + if self.mlp_add is not None: + mlp_x = self.mlp_add(context) + x = x + mlp_x + + return x diff --git a/timm/models/layers/involution.py b/timm/models/layers/involution.py index 0dba9fae..ccdeefcb 100644 --- a/timm/models/layers/involution.py +++ b/timm/models/layers/involution.py @@ -16,7 +16,7 @@ class Involution(nn.Module): kernel_size=3, stride=1, group_size=16, - reduction_ratio=4, + rd_ratio=4, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, ): @@ -28,12 +28,12 @@ class Involution(nn.Module): self.groups = self.channels // self.group_size self.conv1 = ConvBnAct( in_channels=channels, - out_channels=channels // reduction_ratio, + out_channels=channels // rd_ratio, kernel_size=1, norm_layer=norm_layer, act_layer=act_layer) self.conv2 = self.conv = create_conv2d( - in_channels=channels // reduction_ratio, + in_channels=channels // rd_ratio, out_channels=kernel_size**2 * self.groups, kernel_size=1, stride=1) diff --git a/timm/models/layers/mlp.py b/timm/models/layers/mlp.py index b3f8de11..4739ba74 100644 --- a/timm/models/layers/mlp.py +++ b/timm/models/layers/mlp.py @@ -77,3 +77,26 @@ class GatedMlp(nn.Module): x = self.fc2(x) x = self.drop(x) return x + + +class ConvMlp(nn.Module): + """ MLP using 1x1 convs that keeps spatial dims + """ + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index 2925e5c7..433552b4 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -12,3 +12,12 @@ class GroupNorm(nn.GroupNorm): def forward(self, x): return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + + +class LayerNorm2d(nn.LayerNorm): + """ Layernorm for channels of '2d' spatial BCHW tensors """ + def __init__(self, num_channels): + super().__init__([num_channels, 1, 1]) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py deleted file mode 100644 index 4354144d..00000000 --- a/timm/models/layers/se.py +++ /dev/null @@ -1,50 +0,0 @@ -from torch import nn as nn -import torch.nn.functional as F - -from .create_act import create_act_layer -from .helpers import make_divisible - - -class SEModule(nn.Module): - """ SE Module as defined in original SE-Nets with a few additions - Additions include: - * min_channels can be specified to keep reduced channel count at a minimum (default: 8) - * divisor can be specified to keep channels rounded to specified values (default: 1) - * reduction channels can be specified directly by arg (if reduction_channels is set) - * reduction channels can be specified by float ratio (if reduction_ratio is set) - """ - def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid', - reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1): - super(SEModule, self).__init__() - if reduction_channels is not None: - reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done - elif reduction_ratio is not None: - reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels) - else: - reduction_channels = make_divisible(channels // reduction, divisor, min_channels) - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) - self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc1(x_se) - x_se = self.act(x_se) - x_se = self.fc2(x_se) - return x * self.gate(x_se) - - -class EffectiveSEModule(nn.Module): - """ 'Effective Squeeze-Excitation - From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 - """ - def __init__(self, channels, gate_layer='hard_sigmoid'): - super(EffectiveSEModule, self).__init__() - self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc(x_se) - return x * self.gate(x_se) diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py new file mode 100644 index 00000000..3e8a05bb --- /dev/null +++ b/timm/models/layers/squeeze_excite.py @@ -0,0 +1,74 @@ +""" Squeeze-and-Excitation Channel Attention + +An SE implementation originally based on PyTorch SE-Net impl. +Has since evolved with additional functionality / configuration. + +Paper: `Squeeze-and-Excitation Networks` - https://arxiv.org/abs/1709.01507 + +Also included is Effective Squeeze-Excitation (ESE). +Paper: `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + +Hacked together by / Copyright 2021 Ross Wightman +""" +from torch import nn as nn + +from .create_act import create_act_layer +from .helpers import make_divisible + + +class SEModule(nn.Module): + """ SE Module as defined in original SE-Nets with a few additions + Additions include: + * divisor can be specified to keep channels % div == 0 (default: 8) + * reduction channels can be specified directly by arg (if rd_channels is set) + * reduction channels can be specified by float rd_ratio (default: 1/16) + * global max pooling can be added to the squeeze aggregation + * customizable activation, normalization, and gate layer + """ + def __init__( + self, channels, rd_ratio=1. / 16, rd_channels=None, rd_divisor=8, add_maxpool=False, + act_layer=nn.ReLU, norm_layer=None, gate_layer='sigmoid'): + super(SEModule, self).__init__() + self.add_maxpool = add_maxpool + if not rd_channels: + rd_channels = make_divisible(channels * rd_ratio, rd_divisor, round_limit=0.) + self.fc1 = nn.Conv2d(channels, rd_channels, kernel_size=1, bias=True) + self.bn = norm_layer(rd_channels) if norm_layer else nn.Identity() + self.act = create_act_layer(act_layer, inplace=True) + self.fc2 = nn.Conv2d(rd_channels, channels, kernel_size=1, bias=True) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc1(x_se) + x_se = self.act(self.bn(x_se)) + x_se = self.fc2(x_se) + return x * self.gate(x_se) + + +SqueezeExcite = SEModule # alias + + +class EffectiveSEModule(nn.Module): + """ 'Effective Squeeze-Excitation + From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 + """ + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'): + super(EffectiveSEModule, self).__init__() + self.add_maxpool = add_maxpool + self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) + self.gate = create_act_layer(gate_layer) + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + if self.add_maxpool: + # experimental codepath, may remove or change + x_se = 0.5 * x_se + 0.5 * x.amax((2, 3), keepdim=True) + x_se = self.fc(x_se) + return x * self.gate(x_se) + + +EffectiveSqueezeExcite = EffectiveSEModule # alias diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py index 1b67581e..593796a5 100644 --- a/timm/models/nfnet.py +++ b/timm/models/nfnet.py @@ -182,7 +182,7 @@ def _nfres_cfg( def _nfreg_cfg(depths, channels=(48, 104, 208, 440)): num_features = 1280 * channels[-1] // 440 - attn_kwargs = dict(reduction_ratio=0.5, divisor=8) + attn_kwargs = dict(rd_ratio=0.5) cfg = NfCfg( depths=depths, channels=channels, stem_type='3x3', group_size=8, width_factor=0.75, bottle_ratio=2.25, num_features=num_features, reg=True, attn_layer='se', attn_kwargs=attn_kwargs) @@ -193,7 +193,7 @@ def _nfnet_cfg( depths, channels=(256, 512, 1536, 1536), group_size=128, bottle_ratio=0.5, feat_mult=2., act_layer='gelu', attn_layer='se', attn_kwargs=None): num_features = int(channels[-1] * feat_mult) - attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(reduction_ratio=0.5, divisor=8) + attn_kwargs = attn_kwargs if attn_kwargs is not None else dict(rd_ratio=0.5) cfg = NfCfg( depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=group_size, bottle_ratio=bottle_ratio, extra_conv=True, num_features=num_features, act_layer=act_layer, @@ -202,11 +202,10 @@ def _nfnet_cfg( def _dm_nfnet_cfg(depths, channels=(256, 512, 1536, 1536), act_layer='gelu', skipinit=True): - attn_kwargs = dict(reduction_ratio=0.5, divisor=8) cfg = NfCfg( depths=depths, channels=channels, stem_type='deep_quad', stem_chs=128, group_size=128, bottle_ratio=0.5, extra_conv=True, gamma_in_act=True, same_padding=True, skipinit=skipinit, - num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=attn_kwargs) + num_features=int(channels[-1] * 2.0), act_layer=act_layer, attn_layer='se', attn_kwargs=dict(rd_ratio=0.5)) return cfg @@ -243,7 +242,7 @@ model_cfgs = dict( # Experimental 'light' versions of NFNet-F that are little leaner nfnet_l0=_nfnet_cfg( depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, - attn_kwargs=dict(reduction_ratio=0.25, divisor=8), act_layer='silu'), + attn_kwargs=dict(rd_ratio=0.25, rd_divisor=8), act_layer='silu'), eca_nfnet_l0=_nfnet_cfg( depths=(1, 2, 6, 3), feat_mult=1.5, group_size=64, bottle_ratio=0.25, attn_layer='eca', attn_kwargs=dict(), act_layer='silu'), @@ -272,9 +271,9 @@ model_cfgs = dict( nf_resnet50=_nfres_cfg(depths=(3, 4, 6, 3)), nf_resnet101=_nfres_cfg(depths=(3, 4, 23, 3)), - nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), - nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), - nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(reduction_ratio=1/16)), + nf_seresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), + nf_seresnet101=_nfres_cfg(depths=(3, 4, 23, 3), attn_layer='se', attn_kwargs=dict(rd_ratio=1/16)), nf_ecaresnet26=_nfres_cfg(depths=(2, 2, 2, 2), attn_layer='eca', attn_kwargs=dict()), nf_ecaresnet50=_nfres_cfg(depths=(3, 4, 6, 3), attn_layer='eca', attn_kwargs=dict()), diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 3b7dba52..6a381074 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -146,7 +146,7 @@ class Bottleneck(nn.Module): groups=groups, **cargs) if se_ratio: se_channels = int(round(in_chs * se_ratio)) - self.se = SEModule(bottleneck_chs, reduction_channels=se_channels) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels) else: self.se = None cargs['act_layer'] = None diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 2b0b0339..2f02f12a 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -1122,7 +1122,7 @@ def resnetrs50(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 6, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1135,7 +1135,7 @@ def resnetrs101(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 4, 23, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1148,7 +1148,7 @@ def resnetrs152(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 8, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1161,7 +1161,7 @@ def resnetrs200(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[3, 24, 36, 3], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1174,7 +1174,7 @@ def resnetrs270(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 29, 53, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1188,7 +1188,7 @@ def resnetrs350(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 36, 72, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) @@ -1201,7 +1201,7 @@ def resnetrs420(pretrained=False, **kwargs): Paper: Revisiting ResNets - https://arxiv.org/abs/2103.07579 Pretrained weights from https://github.com/tensorflow/tpu/tree/bee9c4f6/models/official/resnet/resnet_rs """ - attn_layer = partial(get_attn('se'), reduction_ratio=0.25) + attn_layer = partial(get_attn('se'), rd_ratio=0.25) model_args = dict( block=Bottleneck, layers=[4, 44, 87, 4], stem_width=32, stem_type='deep', replace_stem_pool=True, avg_down=True, block_args=dict(attn_layer=attn_layer), **kwargs) diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index 859b584e..7ab8d659 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -11,11 +11,12 @@ Copyright 2020 Ross Wightman """ import torch.nn as nn +from functools import partial from math import ceil from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible +from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule from .registry import register_model from .efficientnet_builder import efficientnet_init_weights @@ -48,26 +49,7 @@ default_cfgs = dict( url=''), ) - -class SEWithNorm(nn.Module): - - def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None, - gate_layer='sigmoid'): - super(SEWithNorm, self).__init__() - reduction_channels = reduction_channels or make_divisible(int(channels * se_ratio), divisor=divisor) - self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True) - self.bn = nn.BatchNorm2d(reduction_channels) - self.act = act_layer(inplace=True) - self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True) - self.gate = create_act_layer(gate_layer) - - def forward(self, x): - x_se = x.mean((2, 3), keepdim=True) - x_se = self.fc1(x_se) - x_se = self.bn(x_se) - x_se = self.act(x_se) - x_se = self.fc2(x_se) - return x * self.gate(x_se) +SEWithNorm = partial(SEModule, norm_layer=nn.BatchNorm2d) class LinearBottleneck(nn.Module): @@ -86,7 +68,10 @@ class LinearBottleneck(nn.Module): self.conv_exp = None self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) - self.se = SEWithNorm(dw_chs, se_ratio=se_ratio, divisor=ch_div) if se_ratio > 0. else None + if se_ratio > 0: + self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) + else: + self.se = None self.act_dw = create_act_layer(dw_act_layer) self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py index 9fb34c20..372bfb7b 100644 --- a/timm/models/tresnet.py +++ b/timm/models/tresnet.py @@ -84,8 +84,8 @@ class BasicBlock(nn.Module): self.relu = nn.ReLU(inplace=True) self.downsample = downsample self.stride = stride - reduction_chs = max(planes * self.expansion // 4, 64) - self.se = SEModule(planes * self.expansion, reduction_channels=reduction_chs) if use_se else None + rd_chs = max(planes * self.expansion // 4, 64) + self.se = SEModule(planes * self.expansion, rd_channels=rd_chs) if use_se else None def forward(self, x): if self.downsample is not None: @@ -125,7 +125,7 @@ class Bottleneck(nn.Module): aa_layer(channels=planes, filt_size=3, stride=2)) reduction_chs = max(planes * self.expansion // 8, 64) - self.se = SEModule(planes, reduction_channels=reduction_chs) if use_se else None + self.se = SEModule(planes, rd_channels=reduction_chs) if use_se else None self.conv3 = conv2d_iabn( planes, planes * self.expansion, kernel_size=1, stride=1, act_layer="identity") diff --git a/timm/models/visformer.py b/timm/models/visformer.py index 33a2fe87..5583ea3c 100644 --- a/timm/models/visformer.py +++ b/timm/models/visformer.py @@ -13,7 +13,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, overlay_external_default_cfg -from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed +from .layers import to_2tuple, trunc_normal_, DropPath, PatchEmbed, LayerNorm2d from .registry import register_model @@ -39,15 +39,6 @@ default_cfgs = dict( ) -class LayerNormBHWC(nn.LayerNorm): - def __init__(self, dim): - super().__init__(dim) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - - class SpatialMlp(nn.Module): def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., group=8, spatial_conv=False): @@ -119,7 +110,7 @@ class Attention(nn.Module): class Block(nn.Module): def __init__(self, dim, num_heads, head_dim_ratio=1., mlp_ratio=4., - drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNormBHWC, + drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm2d, group=8, attn_disabled=False, spatial_conv=False): super().__init__() self.spatial_conv = spatial_conv @@ -148,7 +139,7 @@ class Block(nn.Module): class Visformer(nn.Module): def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, init_channels=32, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0., - norm_layer=LayerNormBHWC, attn_stage='111', pos_embed=True, spatial_conv='111', + norm_layer=LayerNorm2d, attn_stage='111', pos_embed=True, spatial_conv='111', vit_stem=False, group=8, pool=True, conv_init=False, embed_norm=None): super().__init__() self.num_classes = num_classes