From ab49d275de8a9c344aea086fd86d04c4cabb6098 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Dec 2021 13:48:30 -0800 Subject: [PATCH 1/3] Significant norm update * ConvBnAct layer renamed -> ConvNormAct and ConvNormActAa for anti-aliased * Significant update to EfficientNet and MobileNetV3 arch to support NormAct layers and grouped conv (as alternative to depthwise) * Update RegNet to add Z variant * Add Pre variant of XceptionAligned that works with NormAct layers * EvoNorm matches bits_and_tpu branch for merge --- timm/models/byobnet.py | 27 +-- timm/models/cspnet.py | 39 +-- timm/models/densenet.py | 4 +- timm/models/dpn.py | 4 +- timm/models/efficientnet.py | 98 ++++++-- timm/models/efficientnet_blocks.py | 180 ++++++-------- timm/models/efficientnet_builder.py | 66 ++++-- timm/models/layers/__init__.py | 6 +- timm/models/layers/cbam.py | 6 +- timm/models/layers/conv_bn_act.py | 55 ++++- timm/models/layers/create_conv2d.py | 7 +- timm/models/layers/create_norm_act.py | 31 ++- timm/models/layers/drop.py | 23 +- timm/models/layers/evo_norm.py | 60 +++-- timm/models/layers/inplace_abn.py | 4 +- timm/models/layers/non_local_attn.py | 10 +- timm/models/layers/norm_act.py | 111 +++++++-- timm/models/layers/pooled_attn.py | 143 +++++++++++ timm/models/layers/selective_kernel.py | 17 +- timm/models/layers/separable_conv.py | 19 +- timm/models/layers/split_attn.py | 7 +- timm/models/mobilenetv3.py | 9 +- timm/models/nasnet.py | 4 +- timm/models/pnasnet.py | 4 +- timm/models/regnet.py | 313 +++++++++++++++---------- timm/models/resnest.py | 12 +- timm/models/resnet.py | 32 +-- timm/models/rexnet.py | 12 +- timm/models/sknet.py | 27 +-- timm/models/vovnet.py | 20 +- timm/models/xception_aligned.py | 128 +++++++--- 31 files changed, 955 insertions(+), 523 deletions(-) create mode 100644 timm/models/layers/pooled_attn.py diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 44f26e4e..e7faa63d 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -34,8 +34,8 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply -from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ +from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ + create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d from .registry import register_model @@ -921,7 +921,7 @@ def num_groups(group_size, channels): @dataclass class LayerFn: - conv_norm_act: Callable = ConvBnAct + conv_norm_act: Callable = ConvNormAct norm_act: Callable = BatchNormAct2d act: Callable = nn.ReLU attn: Optional[Callable] = None @@ -978,7 +978,7 @@ class BasicBlock(nn.Module): self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0]) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( - mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False) + mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False) self.attn_last = nn.Identity() if not attn_last or layers.attn is None else 
layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -1019,11 +1019,9 @@ class BottleneckBlock(nn.Module): self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1) self.conv2_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) + mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block) if extra_conv: - self.conv2b_kxk = layers.conv_norm_act( - mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block) + self.conv2b_kxk = layers.conv_norm_act(mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups) else: self.conv2b_kxk = nn.Identity() self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) @@ -1080,7 +1078,7 @@ class DarkBlock(nn.Module): self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_kxk = layers.conv_norm_act( mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, apply_act=False) + groups=groups, drop_layer=drop_block, apply_act=False) self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity() self.act = nn.Identity() if linear_out else layers.act(inplace=True) @@ -1127,8 +1125,7 @@ class EdgeBlock(nn.Module): apply_act=False, layers=layers) self.conv1_kxk = layers.conv_norm_act( - in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) + in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block) self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs) self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False) self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs) @@ -1172,7 +1169,7 @@ class RepVggBlock(nn.Module): self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None self.conv_kxk = layers.conv_norm_act( in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block, apply_act=False) + groups=groups, drop_layer=drop_block, apply_act=False) self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False) self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs) self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. 
and use_ident else nn.Identity() @@ -1219,7 +1216,7 @@ class SelfAttnBlock(nn.Module): if extra_conv: self.conv2_kxk = layers.conv_norm_act( mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], - groups=groups, drop_block=drop_block) + groups=groups, drop_layer=drop_block) stride = 1 # striding done via conv if enabled else: self.conv2_kxk = nn.Identity() @@ -1466,8 +1463,8 @@ def create_byob_stages( def get_layer_fns(cfg: ByoModelCfg): act = get_act_layer(cfg.act_layer) - norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) - conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) + norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act) + conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) diff --git a/timm/models/cspnet.py b/timm/models/cspnet.py index 39d16200..aa57bd88 100644 --- a/timm/models/cspnet.py +++ b/timm/models/cspnet.py @@ -14,11 +14,10 @@ Hacked together by / Copyright 2020 Ross Wightman """ import torch import torch.nn as nn -import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer +from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer from .registry import register_model @@ -130,7 +129,7 @@ model_cfgs = dict( def create_stem( in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='', - act_layer=None, norm_layer=None, aa_layer=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): stem = nn.Sequential() if not isinstance(out_chs, (tuple, list)): out_chs = [out_chs] @@ -138,7 +137,7 @@ def create_stem( in_c = in_chans for i, out_c in enumerate(out_chs): conv_name = f'conv{i + 1}' - stem.add_module(conv_name, ConvBnAct( + stem.add_module(conv_name, ConvNormAct( in_c, out_c, kernel_size, stride=stride if i == 0 else 1, act_layer=act_layer, norm_layer=norm_layer)) in_c = out_c @@ -161,12 +160,14 @@ class ResBottleneck(nn.Module): attn_layer=None, aa_layer=None, drop_block=None, drop_path=None): super(ResBottleneck, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) - ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) - self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) - self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvNormActAa( + mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, + aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None - self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) + self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs) self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None self.drop_path = drop_path self.act3 = act_layer(inplace=True) @@ -201,9 +202,11 @@ class 
DarkBlock(nn.Module): drop_block=None, drop_path=None): super(DarkBlock, self).__init__() mid_chs = int(round(out_chs * bottle_ratio)) - ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) - self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs) - self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs) + ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs) + self.conv2 = ConvNormActAa( + mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, + aa_layer=aa_layer, drop_layer=drop_block, **ckwargs) self.attn = create_attn(attn_layer, channels=out_chs) self.drop_path = drop_path @@ -235,7 +238,7 @@ class CrossStage(nn.Module): conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer')) if stride != 1 or first_dilation != dilation: - self.conv_down = ConvBnAct( + self.conv_down = ConvNormActAa( in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs) prev_chs = down_chs @@ -246,7 +249,7 @@ class CrossStage(nn.Module): # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also, # there is also special case for the first stage for some of the model that results in uneven split # across the two paths. I did it this way for simplicity for now. - self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) + self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs) prev_chs = exp_chs // 2 # output of conv_exp is always split in two self.blocks = nn.Sequential() @@ -257,8 +260,8 @@ class CrossStage(nn.Module): prev_chs = block_out_chs # transition convs - self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs) - self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) + self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs) + self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs) def forward(self, x): if self.conv_down is not None: @@ -280,7 +283,7 @@ class DarkStage(nn.Module): super(DarkStage, self).__init__() first_dilation = first_dilation or dilation - self.conv_down = ConvBnAct( + self.conv_down = ConvNormActAa( in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups, act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'), aa_layer=block_kwargs.get('aa_layer', None)) @@ -437,7 +440,7 @@ def cspresnext50(pretrained=False, **kwargs): @register_model def cspresnext50_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn') + norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs) @@ -448,7 +451,7 @@ def cspdarknet53(pretrained=False, **kwargs): @register_model def cspdarknet53_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn') + norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs) diff --git a/timm/models/densenet.py b/timm/models/densenet.py index 38a19727..7be15f49 100644 --- 
a/timm/models/densenet.py +++ b/timm/models/densenet.py @@ -14,7 +14,7 @@ from torch.jit.annotations import List from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier +from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier from .registry import register_model __all__ = ['DenseNet'] @@ -370,7 +370,7 @@ def densenet264d_iabn(pretrained=False, **kwargs): r"""Densenet-264 model with deep stem and Inplace-ABN """ def norm_act_fn(num_features, **kwargs): - return create_norm_act('iabn', num_features, **kwargs) + return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs) model = _create_densenet( 'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep', norm_layer=norm_act_fn, pretrained=pretrained, **kwargs) diff --git a/timm/models/dpn.py b/timm/models/dpn.py index c4e380b1..07e4a128 100644 --- a/timm/models/dpn.py +++ b/timm/models/dpn.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier +from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier from .registry import register_model __all__ = ['DPN'] @@ -180,7 +180,7 @@ class DPN(nn.Module): blocks = OrderedDict() # conv1 - blocks['conv1_1'] = ConvBnAct( + blocks['conv1_1'] = ConvNormAct( in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer) blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')] diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 3d50b704..b38b3c0e 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -45,7 +45,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks from .helpers import build_model_with_cfg, default_cfg_for_features -from .layers import create_conv2d, create_classifier +from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct from .registry import register_model __all__ = ['EfficientNet', 'EfficientNetFeatures'] @@ -117,6 +117,20 @@ default_cfgs = { 'efficientnet_l2': _cfg( url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961), + # FIXME experimental + 'efficientnet_b0_gn': _cfg( + url=''), + 'efficientnet_b0_g8': _cfg( + url=''), + 'efficientnet_b0_g16_evos': _cfg( + url=''), + 'efficientnet_b3_gn': _cfg( + url='', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_b3_g8_gn': _cfg( + url='', + input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0), + 'efficientnet_es': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'), 'efficientnet_em': _cfg( @@ -431,6 +445,7 @@ class EfficientNet(nn.Module): super(EfficientNet, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) se_layer = se_layer 
or SqueezeExcite self.num_classes = num_classes self.num_features = num_features @@ -440,8 +455,7 @@ class EfficientNet(nn.Module): if not fix_stem: stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size) - self.act1 = act_layer(inplace=True) + self.bn1 = norm_act_layer(stem_size, inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -453,17 +467,16 @@ class EfficientNet(nn.Module): # Head + Pooling self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type) - self.bn2 = norm_layer(self.num_features) - self.act2 = act_layer(inplace=True) + self.bn2 = norm_act_layer(self.num_features, inplace=True) self.global_pool, self.classifier = create_classifier( self.num_features, self.num_classes, pool_type=global_pool) efficientnet_init_weights(self) def as_sequential(self): - layers = [self.conv_stem, self.bn1, self.act1] + layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) - layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool]) + layers.extend([self.conv_head, self.bn2, self.global_pool]) layers.extend([nn.Dropout(self.drop_rate), self.classifier]) return nn.Sequential(*layers) @@ -478,11 +491,9 @@ class EfficientNet(nn.Module): def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) - x = self.act1(x) x = self.blocks(x) x = self.conv_head(x) x = self.bn2(x) - x = self.act2(x) return x def forward(self, x): @@ -506,6 +517,7 @@ class EfficientNetFeatures(nn.Module): super(EfficientNetFeatures, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) se_layer = se_layer or SqueezeExcite self.drop_rate = drop_rate @@ -513,8 +525,7 @@ class EfficientNetFeatures(nn.Module): if not fix_stem: stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size) - self.act1 = act_layer(inplace=True) + self.bn1 = norm_act_layer(stem_size, inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -536,7 +547,6 @@ class EfficientNetFeatures(nn.Module): def forward(self, x) -> List[torch.Tensor]: x = self.conv_stem(x) x = self.bn1(x) - x = self.act1(x) if self.feature_hooks is None: features = [] if 0 in self._stage_out_idx: @@ -767,7 +777,9 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): return model -def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): +def _gen_efficientnet( + variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8, + group_size=None, pretrained=False, **kwargs): """Creates an EfficientNet model. 
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py @@ -800,9 +812,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre ['ir_r4_k5_s2_e6_c192_se0.25'], ['ir_r1_k3_s1_e6_c320_se0.25'], ] - round_chs_fn = partial(round_channels, multiplier=channel_multiplier) + round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor) model_kwargs = dict( - block_args=decode_arch_def(arch_def, depth_multiplier), + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), num_features=round_chs_fn(1280), stem_size=32, round_chs_fn=round_chs_fn, @@ -814,7 +826,8 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre return model -def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): +def _gen_efficientnet_edge( + variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs): """ Creates an EfficientNet-EdgeTPU model Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu @@ -832,7 +845,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0 ] round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( - block_args=decode_arch_def(arch_def, depth_multiplier), + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), num_features=round_chs_fn(1280), stem_size=32, round_chs_fn=round_chs_fn, @@ -946,7 +959,7 @@ def _gen_efficientnetv2_base( def _gen_efficientnetv2_s( - variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs): + variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs): """ Creates an EfficientNet-V2 Small model Ref impl: https://github.com/google/automl/tree/master/efficientnetv2 @@ -972,7 +985,7 @@ def _gen_efficientnetv2_s( round_chs_fn = partial(round_channels, multiplier=channel_multiplier) model_kwargs = dict( - block_args=decode_arch_def(arch_def, depth_multiplier), + block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size), num_features=round_chs_fn(num_features), stem_size=24, round_chs_fn=round_chs_fn, @@ -1366,6 +1379,52 @@ def efficientnet_l2(pretrained=False, **kwargs): return model +# FIXME experimental group cong / GroupNorm / EvoNorm experiments +@register_model +def efficientnet_b0_gn(pretrained=False, **kwargs): + """ EfficientNet-B0 + GroupNorm""" + model = _gen_efficientnet( + 'efficientnet_b0_gn', norm_layer=partial(GroupNormAct, group_size=8), pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0_g8(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ group conv + BN""" + model = _gen_efficientnet( + 'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b0_g16_evos(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ group 16 conv + EvoNorm""" + model = _gen_efficientnet( + 'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16, + norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3_gn(pretrained=False, **kwargs): + """ EfficientNet-B3 w/ GroupNorm """ + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3_gn', 
channel_multiplier=1.2, depth_multiplier=1.4, channel_divisor=16, + norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs) + return model + + +@register_model +def efficientnet_b3_g8_gn(pretrained=False, **kwargs): + """ EfficientNet-B3 w/ grouped conv + BN""" + # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2 + model = _gen_efficientnet( + 'efficientnet_b3_g8_gn', channel_multiplier=1.2, depth_multiplier=1.4, group_size=8, channel_divisor=16, + norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs) + return model + + @register_model def efficientnet_es(pretrained=False, **kwargs): """ EfficientNet-Edge Small. """ @@ -1373,6 +1432,7 @@ def efficientnet_es(pretrained=False, **kwargs): 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) return model + @register_model def efficientnet_es_pruned(pretrained=False, **kwargs): """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0""" diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index b1fec449..0e91319b 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -2,18 +2,31 @@ Hacked together by / Copyright 2020 Ross Wightman """ +import math import torch import torch.nn as nn from torch.nn import functional as F -from .layers import create_conv2d, drop_path, make_divisible, create_act_layer -from .layers.activations import sigmoid +from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer __all__ = [ 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual'] +def num_groups(group_size, channels): + if not group_size: # 0 or None + return 1 # normal conv with 1 group + else: + # NOTE group_size == 1 -> depthwise conv + #assert channels % group_size == 0 + if channels % group_size != 0: + num_groups = math.floor(channels / group_size) + print(channels, group_size, num_groups) + return int(num_groups) + return channels // group_size + + class SqueezeExcite(nn.Module): """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family @@ -51,31 +64,30 @@ class ConvBnAct(nn.Module): """ Conv + Norm Layer + Activation w/ optional skip connection """ def __init__( - self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='', + self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='', skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.): super(ConvBnAct, self).__init__() - self.has_residual = skip and stride == 1 and in_chs == out_chs - self.drop_path_rate = drop_path_rate - self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn1 = norm_layer(out_chs) - self.act1 = act_layer(inplace=True) + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + groups = num_groups(group_size, in_chs) + self.has_skip = skip and stride == 1 and in_chs == out_chs + + self.conv = create_conv2d( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + self.bn1 = norm_act_layer(out_chs, inplace=True) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): if location == 'expansion': # output of conv after act, same as block coutput - info = dict(module='act1', hook_type='forward', 
num_chs=self.conv.out_channels) + return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels) else: # location == 'bottleneck', block output - info = dict(module='', hook_type='', num_chs=self.conv.out_channels) - return info + return dict(module='', hook_type='', num_chs=self.conv.out_channels) def forward(self, x): shortcut = x x = self.conv(x) x = self.bn1(x) - x = self.act1(x) - if self.has_residual: - if self.drop_path_rate > 0.: - x = drop_path(x, self.drop_path_rate, self.training) - x += shortcut + if self.has_skip: + x = x + self.drop_path(shortcut) return x @@ -85,50 +97,41 @@ class DepthwiseSeparableConv(nn.Module): (factor of 1.0). This is an alternative to having a IR with an optional first pw conv. """ def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(DepthwiseSeparableConv, self).__init__() - self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + groups = num_groups(group_size, in_chs) + self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip self.has_pw_act = pw_act # activation after point-wise conv - self.drop_path_rate = drop_path_rate self.conv_dw = create_conv2d( - in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True) - self.bn1 = norm_layer(in_chs) - self.act1 = act_layer(inplace=True) + in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups) + self.bn1 = norm_act_layer(in_chs, inplace=True) # Squeeze-and-excitation self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity() self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_layer(out_chs) - self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity() + self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): if location == 'expansion': # after SE, input to PW - info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) + return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels) else: # location == 'bottleneck', block output - info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) - return info + return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels) def forward(self, x): shortcut = x - x = self.conv_dw(x) x = self.bn1(x) - x = self.act1(x) - x = self.se(x) - x = self.conv_pw(x) x = self.bn2(x) - x = self.act2(x) - - if self.has_residual: - if self.drop_path_rate > 0.: - x = drop_path(x, self.drop_path_rate, self.training) - x += shortcut + if self.has_skip: + x = x + self.drop_path(shortcut) return x @@ -143,66 +146,51 @@ class InvertedResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.): super(InvertedResidual, self).__init__() + norm_act_layer 
= get_norm_act_layer(norm_layer, act_layer) conv_kwargs = conv_kwargs or {} mid_chs = make_divisible(in_chs * exp_ratio) - self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.drop_path_rate = drop_path_rate + groups = num_groups(group_size, mid_chs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip # Point-wise expansion self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) - self.bn1 = norm_layer(mid_chs) - self.act1 = act_layer(inplace=True) + self.bn1 = norm_act_layer(mid_chs, inplace=True) # Depth-wise convolution self.conv_dw = create_conv2d( mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation, - padding=pad_type, depthwise=True, **conv_kwargs) - self.bn2 = norm_layer(mid_chs) - self.act2 = act_layer(inplace=True) + groups=groups, padding=pad_type, **conv_kwargs) + self.bn2 = norm_act_layer(mid_chs, inplace=True) # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) - self.bn3 = norm_layer(out_chs) + self.bn3 = norm_act_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): if location == 'expansion': # after SE, input to PWL - info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) else: # location == 'bottleneck', block output - info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) - return info + return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) def forward(self, x): shortcut = x - - # Point-wise expansion x = self.conv_pw(x) x = self.bn1(x) - x = self.act1(x) - - # Depth-wise convolution x = self.conv_dw(x) x = self.bn2(x) - x = self.act2(x) - - # Squeeze-and-excitation x = self.se(x) - - # Point-wise linear projection x = self.conv_pwl(x) x = self.bn3(x) - - if self.has_residual: - if self.drop_path_rate > 0.: - x = drop_path(x, self.drop_path_rate, self.training) - x += shortcut - + if self.has_skip: + x = x + self.drop_path(shortcut) return x @@ -210,7 +198,7 @@ class CondConvResidual(InvertedResidual): """ Inverted residual block w/ CondConv routing""" def __init__( - self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='', + self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='', noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.): @@ -218,8 +206,8 @@ class CondConvResidual(InvertedResidual): conv_kwargs = dict(num_experts=self.num_experts) super(CondConvResidual, self).__init__( - in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type, - act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size, + pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs, drop_path_rate=drop_path_rate) @@ -227,32 +215,17 @@ class CondConvResidual(InvertedResidual): def 
forward(self, x): shortcut = x - - # CondConv routing - pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) - - # Point-wise expansion x = self.conv_pw(x, routing_weights) x = self.bn1(x) - x = self.act1(x) - - # Depth-wise convolution x = self.conv_dw(x, routing_weights) x = self.bn2(x) - x = self.act2(x) - - # Squeeze-and-excitation x = self.se(x) - - # Point-wise linear projection x = self.conv_pwl(x, routing_weights) x = self.bn3(x) - - if self.has_residual: - if self.drop_path_rate > 0.: - x = drop_path(x, self.drop_path_rate, self.training) - x += shortcut + if self.has_skip: + x = x + self.drop_path(shortcut) return x @@ -269,55 +242,44 @@ class EdgeResidual(nn.Module): """ def __init__( - self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='', + self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='', force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.): super(EdgeResidual, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) if force_in_chs > 0: mid_chs = make_divisible(force_in_chs * exp_ratio) else: mid_chs = make_divisible(in_chs * exp_ratio) - self.has_residual = (in_chs == out_chs and stride == 1) and not noskip - self.drop_path_rate = drop_path_rate + groups = num_groups(group_size, in_chs) + self.has_skip = (in_chs == out_chs and stride == 1) and not noskip # Expansion convolution self.conv_exp = create_conv2d( - in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type) - self.bn1 = norm_layer(mid_chs) - self.act1 = act_layer(inplace=True) + in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type) + self.bn1 = norm_act_layer(mid_chs, inplace=True) # Squeeze-and-excitation self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity() # Point-wise linear projection self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type) - self.bn2 = norm_layer(out_chs) + self.bn2 = norm_act_layer(out_chs, apply_act=False) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity() def feature_info(self, location): if location == 'expansion': # after SE, before PWL - info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) + return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels) else: # location == 'bottleneck', block output - info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) - return info + return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels) def forward(self, x): shortcut = x - - # Expansion convolution x = self.conv_exp(x) x = self.bn1(x) - x = self.act1(x) - - # Squeeze-and-excitation x = self.se(x) - - # Point-wise linear projection x = self.conv_pwl(x) x = self.bn2(x) - - if self.has_residual: - if self.drop_path_rate > 0.: - x = drop_path(x, self.drop_path_rate, self.training) - x += shortcut - + if self.has_skip: + x = x + self.drop_path(shortcut) return x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index a23e8273..a102a872 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -139,60 +139,52 @@ def _decode_block_str(block_str): exp_kernel_size = 
_parse_ksize(options['a']) if 'a' in options else 1 pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def - num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + block_args = dict( + block_type=block_type, + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) if block_type == 'ir': - block_args = dict( - block_type=block_type, + block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), exp_kernel_size=exp_kernel_size, pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), exp_ratio=float(options['e']), se_ratio=float(options['se']) if 'se' in options else 0., - stride=int(options['s']), - act_layer=act_layer, noskip=skip is False, - ) + )) if 'cc' in options: block_args['num_experts'] = int(options['cc']) elif block_type == 'ds' or block_type == 'dsa': - block_args = dict( - block_type=block_type, + block_args.update(dict( dw_kernel_size=_parse_ksize(options['k']), pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), se_ratio=float(options['se']) if 'se' in options else 0., - stride=int(options['s']), - act_layer=act_layer, pw_act=block_type == 'dsa', noskip=block_type == 'dsa' or skip is False, - ) + )) elif block_type == 'er': - block_args = dict( - block_type=block_type, + block_args.update(dict( exp_kernel_size=_parse_ksize(options['k']), pw_kernel_size=pw_kernel_size, - out_chs=int(options['c']), exp_ratio=float(options['e']), force_in_chs=force_in_chs, se_ratio=float(options['se']) if 'se' in options else 0., - stride=int(options['s']), - act_layer=act_layer, noskip=skip is False, - ) + )) elif block_type == 'cn': - block_args = dict( - block_type=block_type, + block_args.update(dict( kernel_size=int(options['k']), - out_chs=int(options['c']), - stride=int(options['s']), - act_layer=act_layer, skip=skip is True, - ) + )) else: assert False, 'Unknown block type (%s)' % block_type + if 'gs' in options: + block_args['group_size'] = options['gs'] return block_args, num_repeat @@ -235,7 +227,27 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c return sa_scaled -def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): +def decode_arch_def( + arch_def, + depth_multiplier=1.0, + depth_trunc='ceil', + experts_multiplier=1, + fix_first_last=False, + group_size=None, +): + """ Decode block architecture definition strings -> block kwargs + + Args: + arch_def: architecture definition strings, list of list of strings + depth_multiplier: network depth multiplier + depth_trunc: networ depth truncation mode when applying multiplier + experts_multiplier: CondConv experts multiplier + fix_first_last: fix first and last block depths when multiplier is applied + group_size: group size override for all blocks that weren't explicitly set in arch string + + Returns: + list of list of block kwargs + """ arch_args = [] if isinstance(depth_multiplier, tuple): assert len(depth_multiplier) == len(arch_def) @@ -250,6 +262,8 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_ ba, rep = _decode_block_str(block_str) if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: ba['num_experts'] *= experts_multiplier + if group_size is not None: + ba.setdefault('group_size', group_size) stack_args.append(ba) repeats.append(rep) if fix_first_last and (stack_idx == 0 or 
stack_idx == len(arch_def) - 1): diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 0ed0c3af..1319cc74 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -7,11 +7,11 @@ from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config from .conv2d_same import Conv2dSame, conv2d_same -from .conv_bn_act import ConvBnAct +from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d -from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act +from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ @@ -32,7 +32,7 @@ from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite from .selective_kernel import SelectiveKernel -from .separable_conv import SeparableConv2d, SeparableConvBnAct +from .separable_conv import SeparableConv2d, SeparableConvNormAct from .space_to_depth import SpaceToDepthModule from .split_attn import SplitAttn from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py index bacf5cf0..576a8306 100644 --- a/timm/models/layers/cbam.py +++ b/timm/models/layers/cbam.py @@ -11,7 +11,7 @@ import torch from torch import nn as nn import torch.nn.functional as F -from .conv_bn_act import ConvBnAct +from .conv_bn_act import ConvNormAct from .create_act import create_act_layer, get_act_layer from .helpers import make_divisible @@ -56,7 +56,7 @@ class SpatialAttn(nn.Module): """ def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(SpatialAttn, self).__init__() - self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None) + self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False) self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -70,7 +70,7 @@ class LightSpatialAttn(nn.Module): """ def __init__(self, kernel_size=7, gate_layer='sigmoid'): super(LightSpatialAttn, self).__init__() - self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None) + self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False) self.gate = create_act_layer(gate_layer) def forward(self, x): diff --git a/timm/models/layers/conv_bn_act.py b/timm/models/layers/conv_bn_act.py index 33005c37..af010573 100644 --- a/timm/models/layers/conv_bn_act.py +++ b/timm/models/layers/conv_bn_act.py @@ -5,14 +5,46 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .create_conv2d import create_conv2d -from .create_norm_act import convert_norm_act +from .create_norm_act import get_norm_act_layer -class ConvBnAct(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, - bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, - drop_block=None): - super(ConvBnAct, self).__init__() +class ConvNormAct(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, 
stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None): + super(ConvNormAct, self).__init__() + self.conv = create_conv2d( + in_channels, out_channels, kernel_size, stride=stride, + padding=padding, dilation=dilation, groups=groups, bias=bias) + + # NOTE for backwards compatibility with models that use separate norm and act layer definitions + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + + @property + def in_channels(self): + return self.conv.in_channels + + @property + def out_channels(self): + return self.conv.out_channels + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +ConvBnAct = ConvNormAct + + +class ConvNormActAa(nn.Module): + def __init__( + self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, + bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None): + super(ConvNormActAa, self).__init__() use_aa = aa_layer is not None self.conv = create_conv2d( @@ -20,9 +52,11 @@ class ConvBnAct(nn.Module): padding=padding, dilation=dilation, groups=groups, bias=bias) # NOTE for backwards compatibility with models that use separate norm and act layer definitions - norm_act_layer = convert_norm_act(norm_layer, act_layer) - self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) - self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + # NOTE for backwards (weight) compatibility, norm layer name remains `.bn` + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) + self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity() @property def in_channels(self): @@ -35,6 +69,5 @@ class ConvBnAct(nn.Module): def forward(self, x): x = self.conv(x) x = self.bn(x) - if self.aa is not None: - x = self.aa(x) + x = self.aa(x) return x diff --git a/timm/models/layers/create_conv2d.py b/timm/models/layers/create_conv2d.py index 3a0cc03a..ac9489ce 100644 --- a/timm/models/layers/create_conv2d.py +++ b/timm/models/layers/create_conv2d.py @@ -16,7 +16,12 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs): """ if isinstance(kernel_size, list): assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently - assert 'groups' not in kwargs # MixedConv groups are defined by kernel list + if 'groups' in kwargs: + groups = kwargs.pop('groups') + if groups == in_channels: + kwargs['depthwise'] = True + else: + assert groups == 1 # We're going to use only lists for defining the MixedConv2d kernel groups, # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. 
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index 5d4894a0..cd15c2f8 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -11,12 +11,15 @@ import functools from .evo_norm import * from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d -from .norm_act import BatchNormAct2d, GroupNormAct +from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d from .inplace_abn import InplaceAbn _NORM_ACT_MAP = dict( batchnorm=BatchNormAct2d, + batchnorm2d=BatchNormAct2d, groupnorm=GroupNormAct, + layernorm=LayerNormAct, + layernorm2d=LayerNormAct2d, evonormb0=EvoNorm2dB0, evonormb1=EvoNorm2dB1, evonormb2=EvoNorm2dB2, @@ -33,28 +36,19 @@ _NORM_ACT_MAP = dict( ) _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} # has act_layer arg to define act type -_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn} +_NORM_ACT_REQUIRES_ARG = { + BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn} -def get_norm_act_layer(layer_name): - layer_name = layer_name.replace('_', '').lower().split('-')[0] - layer = _NORM_ACT_MAP.get(layer_name, None) - assert layer is not None, "Invalid norm_act layer (%s)" % layer_name - return layer - - -def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs): - layer_parts = layer_name.split('-') # e.g. batchnorm-leaky_relu - assert len(layer_parts) in (1, 2) - layer = get_norm_act_layer(layer_parts[0]) - #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? +def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs): + layer = get_norm_act_layer(layer_name, act_layer=act_layer) layer_instance = layer(num_features, apply_act=apply_act, **kwargs) if jit: layer_instance = torch.jit.script(layer_instance) return layer_instance -def convert_norm_act(norm_layer, act_layer): +def get_norm_act_layer(norm_layer, act_layer=None): assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) norm_act_kwargs = {} @@ -65,7 +59,8 @@ def convert_norm_act(norm_layer, act_layer): norm_layer = norm_layer.func if isinstance(norm_layer, str): - norm_act_layer = get_norm_act_layer(norm_layer) + layer_name = norm_layer.replace('_', '').lower().split('-')[0] + norm_act_layer = _NORM_ACT_MAP.get(layer_name, None) elif norm_layer in _NORM_ACT_TYPES: norm_act_layer = norm_layer elif isinstance(norm_layer, types.FunctionType): @@ -77,6 +72,10 @@ def convert_norm_act(norm_layer, act_layer): norm_act_layer = BatchNormAct2d elif type_name.startswith('groupnorm'): norm_act_layer = GroupNormAct + elif type_name.startswith('layernorm2d'): + norm_act_layer = LayerNormAct2d + elif type_name.startswith('layernorm'): + norm_act_layer = LayerNormAct else: assert False, f"No equivalent norm_act layer for {type_name}" diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index 90c1933a..fb20dfce 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -20,7 +20,7 @@ import torch.nn.functional as F def drop_block_2d( - x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, + x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: 
float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf @@ -32,7 +32,7 @@ def drop_block_2d( clipped_block_size = min(block_size, min(W, H)) # seed_drop_rate, the gamma parameter gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( - (W - block_size + 1) * (H - block_size + 1)) + (W - block_size + 1) * (H - block_size + 1)) # Forces the block to be inside the feature map. w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) @@ -104,14 +104,16 @@ def drop_block_fast_2d( class DropBlock2d(nn.Module): """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf """ - def __init__(self, - drop_prob=0.1, - block_size=7, - gamma_scale=1.0, - with_noise=False, - inplace=False, - batchwise=False, - fast=True): + + def __init__( + self, + drop_prob=0.1, + block_size=7, + gamma_scale=1.0, + with_noise=False, + inplace=False, + batchwise=False, + fast=True): super(DropBlock2d, self).__init__() self.drop_prob = drop_prob self.gamma_scale = gamma_scale @@ -155,6 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b class DropPath(nn.Module): """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """ + def __init__(self, drop_prob=None, scale_by_keep=True): super(DropPath, self).__init__() self.drop_prob = drop_prob diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index d42c502c..f48d9a83 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -23,6 +23,7 @@ GPU, similar train speeds for EvoNormS variants and BatchNorm. Hacked together by / Copyright 2020 Ross Wightman """ +from typing import Sequence, Union import torch import torch.nn as nn @@ -33,41 +34,57 @@ from .trace_utils import _assert def instance_std(x, eps: float = 1e-5): - rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) - return rms.expand(x.shape) + std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) + return std.expand(x.shape) + + +def instance_std_tpu(x, eps: float = 1e-5): + std = manual_var(x, dim=(2, 3)).add(eps).sqrt() + return std.expand(x.shape) +# instance_std = instance_std_tpu def instance_rms(x, eps: float = 1e-5): - rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype) + rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype) return rms.expand(x.shape) +def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False): + xm = x.mean(dim=dim, keepdim=True) + if diff_sqm: + # difference of squared mean and mean squared, faster on TPU can be less stable + var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0) + else: + var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True) + return var + + def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): B, C, H, W = x.shape x_dtype = x.dtype _assert(C % groups == 0, '') - # x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues - # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt() - x = x.reshape(B, groups, C // groups, H, W) - std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt() - return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype) + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + std = x.float().var(dim=2, 
unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + else: + x = x.reshape(B, groups, C // groups, H, W) + std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype) + return std.expand(x.shape).reshape(B, C, H, W) -def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False): +def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False): # This is a workaround for some stability / odd behaviour of .var and .std # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results B, C, H, W = x.shape _assert(C % groups == 0, '') - x_dtype = x.dtype - x = x.float().reshape(B, groups, C // groups, H, W) - xm = x.mean(dim=(2, 3, 4), keepdim=True) - if diff_sqm: - # difference of squared mean and mean squared, faster on TPU - var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0) + if flatten: + x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + var = manual_var(x, dim=-1, diff_sqm=diff_sqm) else: - var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True) - return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype) -# group_std = group_std_tpu # temporary, for TPU / PT XLA + x = x.reshape(B, groups, C // groups, H, W) + var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm) + return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W) +#group_std = group_std_tpu # FIXME TPU temporary def group_rms(x, groups: int = 32, eps: float = 1e-5): @@ -75,8 +92,8 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5): _assert(C % groups == 0, '') x_dtype = x.dtype x = x.reshape(B, groups, C // groups, H, W) - sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) - return sqm.expand(x.shape).reshape(B, C, H, W) + rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) + return rms.expand(x.shape).reshape(B, C, H, W) class EvoNorm2dB0(nn.Module): @@ -104,6 +121,7 @@ class EvoNorm2dB0(nn.Module): if self.v is not None: if self.training: var = x.float().var(dim=(0, 2, 3), unbiased=False) + # var = manual_var(x, dim=(0, 2, 3)).squeeze() n = x.numel() / x.shape[1] self.running_var.copy_( self.running_var * (1 - self.momentum) + @@ -230,7 +248,7 @@ class EvoNorm2dS0a(EvoNorm2dS0): d = group_std(x, self.groups, self.eps) if self.v is not None: v = self.v.view(v_shape).to(dtype=x_dtype) - x = x * (x * v).sigmoid_() + x = x * (x * v).sigmoid() x = x / d return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) diff --git a/timm/models/layers/inplace_abn.py b/timm/models/layers/inplace_abn.py index 3aae7cf5..a8088933 100644 --- a/timm/models/layers/inplace_abn.py +++ b/timm/models/layers/inplace_abn.py @@ -38,7 +38,7 @@ class InplaceAbn(nn.Module): """ def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, - act_layer="leaky_relu", act_param=0.01, drop_block=None): + act_layer="leaky_relu", act_param=0.01, drop_layer=None): super(InplaceAbn, self).__init__() self.num_features = num_features self.affine = affine @@ -54,7 +54,7 @@ class InplaceAbn(nn.Module): self.act_name = 'elu' elif act_layer == nn.LeakyReLU: self.act_name = 'leaky_relu' - elif act_layer == nn.Identity: + elif act_layer is None or act_layer == nn.Identity: self.act_name = 'identity' else: assert False, f'Invalid act layer {act_layer.__name__} for IABN' diff --git a/timm/models/layers/non_local_attn.py 
b/timm/models/layers/non_local_attn.py index 881fa36d..670e8f24 100644 --- a/timm/models/layers/non_local_attn.py +++ b/timm/models/layers/non_local_attn.py @@ -8,7 +8,7 @@ import torch from torch import nn from torch.nn import functional as F -from .conv_bn_act import ConvBnAct +from .conv_bn_act import ConvNormAct from .helpers import make_divisible from .trace_utils import _assert @@ -74,10 +74,10 @@ class BilinearAttnTransform(nn.Module): def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(BilinearAttnTransform, self).__init__() - self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) - self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.block_size = block_size self.groups = groups self.in_channels = in_channels @@ -132,9 +132,9 @@ class BatNonLocalAttn(nn.Module): super().__init__() if rd_channels is None: rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) - self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) - self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.dropout = nn.Dropout2d(p=drop_rate) def forward(self, x): diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 2e15181f..ae3f75c6 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,5 +1,7 @@ """ Normalization + Activation Layers """ +from typing import Union, List + import torch from torch import nn as nn from torch.nn import functional as F @@ -14,12 +16,13 @@ class BatchNormAct2d(nn.BatchNorm2d): compatible with weights trained with separate bn, act. This is why we inherit from BN instead of composing it as a .bn member. 
""" - def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): + def __init__( + self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): super(BatchNormAct2d, self).__init__( num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: act_args = dict(inplace=True) if inplace else {} self.act = act_layer(**act_args) @@ -29,8 +32,8 @@ class BatchNormAct2d(nn.BatchNorm2d): def _forward_jit(self, x): """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function """ - # exponential_average_factor is self.momentum set to - # (when it is available) only so that if gets updated + # exponential_average_factor is set to self.momentum + # (when it is available) only so that it gets updated # in ONNX graph when this node is exported to ONNX. if self.momentum is None: exponential_average_factor = 0.0 @@ -39,18 +42,38 @@ class BatchNormAct2d(nn.BatchNorm2d): if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None - if self.num_batches_tracked is not None: - self.num_batches_tracked += 1 + if self.num_batches_tracked is not None: # type: ignore[has-type] + self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum - x = F.batch_norm( - x, self.running_mean, self.running_var, self.weight, self.bias, - self.training or not self.track_running_stats, - exponential_average_factor, self.eps) - return x + r""" + Decide whether the mini-batch stats should be used for normalization rather than the buffers. + Mini-batch stats are used in training mode, and in eval mode when buffers are None. + """ + if self.training: + bn_training = True + else: + bn_training = (self.running_mean is None) and (self.running_var is None) + + r""" + Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be + passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are + used for normalization (i.e. in eval mode when buffers are not None). 
+ """ + return F.batch_norm( + x, + # If buffers are not to be tracked, ensure that they won't be updated + self.running_mean if not self.training or self.track_running_stats else None, + self.running_var if not self.training or self.track_running_stats else None, + self.weight, + self.bias, + bn_training, + exponential_average_factor, + self.eps, + ) @torch.jit.ignore def _forward_python(self, x): @@ -62,17 +85,27 @@ class BatchNormAct2d(nn.BatchNorm2d): x = self._forward_jit(x) else: x = self._forward_python(x) + x = self.drop(x) x = self.act(x) return x +def _num_groups(num_channels, num_groups, group_size): + if group_size: + assert num_channels % group_size == 0 + return num_channels // group_size + return num_groups + + class GroupNormAct(nn.GroupNorm): # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args - def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): - super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) - if isinstance(act_layer, str): - act_layer = get_act_layer(act_layer) + def __init__( + self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(GroupNormAct, self).__init__( + _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: act_args = dict(inplace=True) if inplace else {} self.act = act_layer(**act_args) @@ -81,5 +114,47 @@ class GroupNormAct(nn.GroupNorm): def forward(self, x): x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct(nn.LayerNorm): + def __init__( + self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class LayerNormAct2d(nn.LayerNorm): + def __init__( + self, num_channels, eps=1e-5, affine=True, + apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def forward(self, x): + x = F.layer_norm( + x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + x = self.drop(x) x = self.act(x) return x diff --git a/timm/models/layers/pooled_attn.py b/timm/models/layers/pooled_attn.py new file mode 100644 index 00000000..40cf2b34 --- /dev/null 
+++ b/timm/models/layers/pooled_attn.py @@ -0,0 +1,143 @@ +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .helpers import to_2tuple +from .weight_init import trunc_normal_ + + +def rel_logits_1d(q, rel_k, permute_mask: List[int]): + """ Compute relative logits along one dimension + + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + + Args: + q: (batch, heads, height, width, dim) + rel_k: (2 * width - 1, dim) + permute_mask: permute output dim according to this + """ + B, H, W, dim = q.shape + x = (q @ rel_k.transpose(-1, -2)) + x = x.reshape(-1, W, 2 * W -1) + + # pad to shift from relative to absolute indexing + x_pad = F.pad(x, [0, 1]).flatten(1) + x_pad = F.pad(x_pad, [0, W - 1]) + + # reshape and slice out the padded elements + x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) + x = x_pad[:, :W, W - 1:] + + # reshape and tile + x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) + return x.permute(permute_mask) + + +class PosEmbedRel(nn.Module): + """ Relative Position Embedding + As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 + Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 + """ + def __init__(self, feat_size, dim_head, scale): + super().__init__() + self.height, self.width = to_2tuple(feat_size) + self.dim_head = dim_head + self.scale = scale + self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) + self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) + + def forward(self, q): + B, num_heads, HW, _ = q.shape + + # relative logits in width dimension. + q = q.reshape(B * num_heads, self.height, self.width, -1) + rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) + + # relative logits in height dimension. 
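The pad-and-reshape shift inside rel_logits_1d above is easy to misread; a tiny standalone check of the relative-to-absolute re-indexing it performs (illustrative only, the toy W and tensor are made up):

import torch
import torch.nn.functional as F

W = 3
# toy relative logits: row i holds scores for relative offsets -(W-1)..(W-1)
x = torch.arange(W * (2 * W - 1), dtype=torch.float32).reshape(1, W, 2 * W - 1)

# same pad / flatten / reshape / slice sequence as rel_logits_1d
x_pad = F.pad(x, [0, 1]).flatten(1)
x_pad = F.pad(x_pad, [0, W - 1]).reshape(-1, W + 1, 2 * W - 1)
abs_logits = x_pad[:, :W, W - 1:]  # (1, W, W)

# row i, column j of the result is the relative entry at offset j - i
expected = torch.stack([x[0, i, torch.arange(W) - i + W - 1] for i in range(W)])
print(torch.equal(abs_logits[0], expected))  # True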
+ q = q.transpose(1, 2) + rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) + + rel_logits = rel_logits_h + rel_logits_w + rel_logits = rel_logits.reshape(B, num_heads, HW, HW) + return rel_logits + + +class BottleneckAttn(nn.Module): + """ Bottleneck Attention + Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 + """ + def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): + super().__init__() + assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' + dim_out = dim_out or dim + assert dim_out % num_heads == 0 + self.num_heads = num_heads + self.dim_out = dim_out + self.dim_head = dim_out // num_heads + self.scale = self.dim_head ** -0.5 + + self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) + + # NOTE I'm only supporting relative pos embedding for now + self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) + + self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() + + self.reset_parameters() + + def reset_parameters(self): + trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) + trunc_normal_(self.pos_embed.height_rel, std=self.scale) + trunc_normal_(self.pos_embed.width_rel, std=self.scale) + + def forward(self, x): + B, C, H, W = x.shape + assert H == self.pos_embed.height + assert W == self.pos_embed.width + + x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W + x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) + q, k, v = torch.split(x, self.num_heads, dim=1) + + attn_logits = (q @ k.transpose(-1, -2)) * self.scale + attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W + + attn_out = attn_logits.softmax(dim=-1) + attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W + attn_out = self.pool(attn_out) + return attn_out + + +class PoolingAttention(nn.Module): + def __init__(self, in_features: int, attention_features: int, segments: int, max_pool_kernel: int): + super(PoolingAttention, self).__init__() + self.attn = nn.Linear(in_features, attention_features * 5) + self.segments = segments + self.max_pool_kernel = max_pool_kernel + + def forward(self, inp: torch.Tensor): # Shape: [Batch, Sequence, Features] + batch, sequence, features = inp.size() + assert sequence % self.segments == 0 + + qry, key, val, seg, loc = self.attn(inp).chunk(5, 2) # 5x Shape: [Batch, Sequence, AttentionFeatures] + + aggregated = qry.mean(1) # Shape: [Batch, AttentionFeatures] + aggregated = torch.einsum("ba,bsa->bs", aggregated, key) # Shape: [Batch, Sequence] + aggregated = F.softmax(aggregated, 1) + aggregated = torch.einsum("bs,bsa,bza->bza", aggregated, val, + qry) # Shape: [Batch, Sequence, AttentionFeatures] + + pooled_sequence = sequence // self.segments + segment_max_pooled = seg.view(batch, pooled_sequence, self.segments, -1) + segment_max_pooled = segment_max_pooled.max(2, keepdim=True).values # Shape: [Batch, PooledSequence, 1, AttentionFeatures] + segment_max_pooled = segment_max_pooled * qry.view(batch, pooled_sequence, self.segments, -1) # Shape: [Batch, PooledSequence, PoolSize, AttentionFeatures] + segment_max_pooled = segment_max_pooled.view(batch, sequence, -1) # Shape: [Batch, Sequence, AttentionFeatures] + + loc = loc.transpose(1, 2) # Shape: [Batch, AttentionFeatures, Sequence] + local_max_pooled = F.max_pool1d(loc, self.max_pool_kernel, 1, self.max_pool_kernel // 2) + local_max_pooled =
local_max_pooled.transpose(1, 2) # Shape: [Batch, Sequence, AttentionFeatures] + + return aggregated + segment_max_pooled + local_max_pooled \ No newline at end of file diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 1aeb9294..3d71e3aa 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman import torch from torch import nn as nn -from .conv_bn_act import ConvBnAct +from .conv_bn_act import ConvNormActAa from .helpers import make_divisible from .trace_utils import _assert @@ -20,8 +20,7 @@ def _kernel_valid(k): class SelectiveKernelAttn(nn.Module): - def __init__(self, channels, num_paths=2, attn_channels=32, - act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): """ Selective Kernel Attention Module Selective Kernel attention mechanism factored out into its own module. @@ -51,7 +50,7 @@ class SelectiveKernel(nn.Module): def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, - drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None): """ Selective Kernel Convolution Module As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. @@ -72,9 +71,10 @@ class SelectiveKernel(nn.Module): keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count - drop_block (nn.Module): drop block module act_layer (nn.Module): activation layer to use norm_layer (nn.Module): batchnorm/norm layer to use + aa_layer (nn.Module): anti-aliasing module + drop_layer (nn.Module): spatial drop module in convs (drop block, etc) """ super(SelectiveKernel, self).__init__() out_channels = out_channels or in_channels @@ -97,15 +97,14 @@ class SelectiveKernel(nn.Module): groups = min(out_channels, groups) conv_kwargs = dict( - stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, - aa_layer=aa_layer) + stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer, + aa_layer=aa_layer, drop_layer=drop_layer) self.paths = nn.ModuleList([ - ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) + ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) - self.drop_block = drop_block def forward(self, x): if self.split_input: diff --git a/timm/models/layers/separable_conv.py b/timm/models/layers/separable_conv.py index 1ddcb4e6..c081e02b 100644 --- a/timm/models/layers/separable_conv.py +++ b/timm/models/layers/separable_conv.py @@ -8,16 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman from torch import nn as nn from .create_conv2d import create_conv2d -from .create_norm_act import convert_norm_act +from .create_norm_act import get_norm_act_layer -class 
SeparableConvBnAct(nn.Module): +class SeparableConvNormAct(nn.Module): """ Separable Conv w/ trailing Norm and Activation """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, - apply_act=True, drop_block=None): - super(SeparableConvBnAct, self).__init__() + apply_act=True, drop_layer=None): + super(SeparableConvNormAct, self).__init__() self.conv_dw = create_conv2d( in_channels, int(in_channels * channel_multiplier), kernel_size, @@ -26,8 +26,9 @@ class SeparableConvBnAct(nn.Module): self.conv_pw = create_conv2d( int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) - norm_act_layer = convert_norm_act(norm_layer, act_layer) - self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {} + self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs) @property def in_channels(self): @@ -40,11 +41,13 @@ class SeparableConvBnAct(nn.Module): def forward(self, x): x = self.conv_dw(x) x = self.conv_pw(x) - if self.bn is not None: - x = self.bn(x) + x = self.bn(x) return x +SeparableConvBnAct = SeparableConvNormAct + + class SeparableConv2d(nn.Module): """ Separable Conv """ diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index dde601be..ac54f898 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -35,11 +35,10 @@ class SplitAttn(nn.Module): """ def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, - act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): + act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs): super(SplitAttn, self).__init__() out_channels = out_channels or in_channels self.radix = radix - self.drop_block = drop_block mid_chs = out_channels * radix if rd_channels is None: attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) @@ -51,6 +50,7 @@ class SplitAttn(nn.Module): in_channels, mid_chs, kernel_size, stride, padding, dilation, groups=groups * radix, bias=bias, **kwargs) self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() + self.drop = drop_layer() if drop_layer is not None else nn.Identity() self.act0 = act_layer(inplace=True) self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() @@ -61,8 +61,7 @@ class SplitAttn(nn.Module): def forward(self, x): x = self.conv(x) x = self.bn0(x) - if self.drop_block is not None: - x = self.drop_block(x) + x = self.drop(x) x = self.act0(x) B, RC, H, W = x.shape diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index f810eb82..f49a35de 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -20,7 +20,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT from .features import FeatureInfo, FeatureHooks from .helpers import build_model_with_cfg, default_cfg_for_features -from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid +from .layers import SelectAdaptivePool2d, Linear, create_conv2d, 
get_act_fn, get_norm_act_layer from .registry import register_model __all__ = ['MobileNetV3', 'MobileNetV3Features'] @@ -95,6 +95,7 @@ class MobileNetV3(nn.Module): super(MobileNetV3, self).__init__() act_layer = act_layer or nn.ReLU norm_layer = norm_layer or nn.BatchNorm2d + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) se_layer = se_layer or SqueezeExcite self.num_classes = num_classes self.num_features = num_features @@ -103,8 +104,7 @@ class MobileNetV3(nn.Module): # Stem stem_size = round_chs_fn(stem_size) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) - self.bn1 = norm_layer(stem_size) - self.act1 = act_layer(inplace=True) + self.bn1 = norm_act_layer(stem_size, inplace=True) # Middle stages (IR/ER/DS Blocks) builder = EfficientNetBuilder( @@ -125,7 +125,7 @@ class MobileNetV3(nn.Module): efficientnet_init_weights(self) def as_sequential(self): - layers = [self.conv_stem, self.bn1, self.act1] + layers = [self.conv_stem, self.bn1] layers.extend(self.blocks) layers.extend([self.global_pool, self.conv_head, self.act2]) layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) @@ -144,7 +144,6 @@ class MobileNetV3(nn.Module): def forward_features(self, x): x = self.conv_stem(x) x = self.bn1(x) - x = self.act1(x) x = self.blocks(x) x = self.global_pool(x) x = self.conv_head(x) diff --git a/timm/models/nasnet.py b/timm/models/nasnet.py index 2afe82c3..9c257d9d 100644 --- a/timm/models/nasnet.py +++ b/timm/models/nasnet.py @@ -9,7 +9,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import build_model_with_cfg -from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from .registry import register_model __all__ = ['NASNetALarge'] @@ -420,7 +420,7 @@ class NASNetALarge(nn.Module): channels = self.num_features // 24 # 24 is default value for the architecture - self.conv0 = ConvBnAct( + self.conv0 = ConvNormAct( in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) diff --git a/timm/models/pnasnet.py b/timm/models/pnasnet.py index 99918156..208bccf3 100644 --- a/timm/models/pnasnet.py +++ b/timm/models/pnasnet.py @@ -13,7 +13,7 @@ import torch.nn as nn import torch.nn.functional as F from .helpers import build_model_with_cfg -from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier +from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier from .registry import register_model __all__ = ['PNASNet5Large'] @@ -243,7 +243,7 @@ class PNASNet5Large(nn.Module): self.num_features = 4320 assert output_stride == 32 - self.conv_0 = ConvBnAct( + self.conv_0 = ConvNormAct( in_chans, 96, kernel_size=3, stride=2, padding=0, norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 6a381074..8a0689f7 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -15,45 +15,76 @@ Hacked together by / Copyright 2020 Ross Wightman """ import numpy as np import torch.nn as nn +from dataclasses import dataclass +from functools import partial +from typing import Optional, Union, Callable from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from .helpers import build_model_with_cfg -from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath +from 
.helpers import build_model_with_cfg, named_apply +from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct from .registry import register_model -def _mcfg(**kwargs): - cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) - cfg.update(**kwargs) - return cfg +@dataclass +class RegNetCfg: + depth: int = 21 + w0: int = 80 + wa: float = 42.63 + wm: float = 2.66 + group_size: int = 24 + bottle_ratio: float = 1. + se_ratio: float = 0. + stem_width: int = 32 + downsample: Optional[str] = 'conv1x1' + linear_out: bool = False + act_layer: Union[str, Callable] = 'relu' + norm_layer: Union[str, Callable] = 'batchnorm' # Model FLOPS = three trailing digits * 10^8 model_cfgs = dict( - regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13), - regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22), - regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16), - regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16), - regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18), - regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25), - regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23), - regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17), - regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23), - regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19), - regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22), - regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23), - regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25), - regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25), - regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25), - regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25), - regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25), - regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), - regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), - regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25), - regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25), - regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25), - regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25), - regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25), + # RegNet-X + regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13), + regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22), + regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16), + regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16), + regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18), + regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25), + regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23), + regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17), + regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23), + regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19), + regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22), + regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23), + + # RegNet-Y + 
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25), + regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25), + regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25), + regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25), + regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25), + regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25), + regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25), + regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25), + regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25), + regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25), + regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25), + regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25), + + # Experimental + regnety_040s_gn=RegNetCfg( + w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, + act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), + + # RegNet-Z (unverified) + regnetz_005=RegNetCfg( + depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, act_layer='silu', + ), + regnetz_040=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, linear_out=True, act_layer='silu', + ), ) @@ -80,6 +111,7 @@ default_cfgs = dict( regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), + regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), @@ -96,6 +128,11 @@ default_cfgs = dict( url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository crop_pct=1.0, test_input_size=(3, 288, 288)), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), + + regnety_040s_gn=_cfg(url=''), + + regnetz_005=_cfg(url=''), + regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), ) @@ -125,6 +162,40 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8): return widths, num_stages, max_stage, widths_cont +def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): + norm_layer = norm_layer or nn.BatchNorm2d + kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size + dilation = dilation if kernel_size > 1 else 1 + return ConvNormAct( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False) + + +def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): 
+ """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" + norm_layer = norm_layer or nn.BatchNorm2d + avg_stride = stride if dilation == 1 else 1 + pool = nn.Identity() + if stride > 1 or dilation > 1: + avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) + return nn.Sequential(*[ + pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)]) + + +def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None): + assert downsample_type in ('avg', 'conv1x1', '', None) + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + if not downsample_type: + return None # no shortcut, no downsample + elif downsample_type == 'avg': + return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer) + else: + return downsample_conv( + in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer) + else: + return nn.Identity() # identity shortcut (no downsample) + + class Bottleneck(nn.Module): """ RegNet Bottleneck @@ -132,97 +203,70 @@ class Bottleneck(nn.Module): after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. """ - def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, - downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, - drop_block=None, drop_path=None): + def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, + downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path_rate=0.): super(Bottleneck, self).__init__() - bottleneck_chs = int(round(out_chs * bottleneck_ratio)) - groups = bottleneck_chs // group_width - - cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block) - self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) - self.conv2 = ConvBnAct( - bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation, - groups=groups, **cargs) + act_layer = get_act_layer(act_layer) + bottleneck_chs = int(round(out_chs * bottle_ratio)) + groups = bottleneck_chs // group_size + + cargs = dict(act_layer=act_layer, norm_layer=norm_layer) + self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) + self.conv2 = ConvNormAct( + bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], + groups=groups, drop_layer=drop_block, **cargs) if se_ratio: se_channels = int(round(in_chs * se_ratio)) - self.se = SEModule(bottleneck_chs, rd_channels=se_channels) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) else: - self.se = None - cargs['act_layer'] = None - self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs) - self.act3 = act_layer(inplace=True) - self.downsample = downsample - self.drop_path = drop_path - - def zero_init_last_bn(self): + self.se = nn.Identity() + self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs) + self.act3 = nn.Identity() if linear_out else act_layer() + self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, norm_layer=norm_layer) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else 
nn.Identity() + + def zero_init_last(self): nn.init.zeros_(self.conv3.bn.weight) def forward(self, x): shortcut = x x = self.conv1(x) x = self.conv2(x) - if self.se is not None: - x = self.se(x) + x = self.se(x) x = self.conv3(x) - if self.drop_path is not None: - x = self.drop_path(x) if self.downsample is not None: - shortcut = self.downsample(shortcut) - x += shortcut + # NOTE stuck with downsample as the attr name due to weight compatibility + # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() + x = x + self.drop_path(self.downsample(shortcut)) x = self.act3(x) return x -def downsample_conv( - in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): - norm_layer = norm_layer or nn.BatchNorm2d - kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size - dilation = dilation if kernel_size > 1 else 1 - return ConvBnAct( - in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None) - - -def downsample_avg( - in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None): - """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" - norm_layer = norm_layer or nn.BatchNorm2d - avg_stride = stride if dilation == 1 else 1 - pool = nn.Identity() - if stride > 1 or dilation > 1: - avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d - pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) - return nn.Sequential(*[ - pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)]) - - class RegStage(nn.Module): """Stage (sequence of blocks w/ the same output shape).""" - def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, - block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None): + def __init__( + self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck, + se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_path_rates=None, drop_block=None): super(RegStage, self).__init__() - block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args + block_kwargs = dict( + bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample, + linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block) first_dilation = 1 if dilation in (1, 2) else 2 for i in range(depth): block_stride = stride if i == 0 else 1 block_in_chs = in_chs if i == 0 else out_chs - block_dilation = first_dilation if i == 0 else dilation - if drop_path_rates is not None and drop_path_rates[i] > 0.: - drop_path = DropPath(drop_path_rates[i]) - else: - drop_path = None - if (block_in_chs != out_chs) or (block_stride != 1): - proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation) - else: - proj_block = None - + block_dilation = (first_dilation, dilation) + dpr = drop_path_rates[i] if drop_path_rates is not None else 0. 
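For reference, the per-block rates consumed here come from a linear ramp that is split per stage in _get_stage_params further down; a small illustration (the stage depths below are made-up example values):

import numpy as np

depth = 21
stage_depths = [2, 5, 13, 1]  # hypothetical per-stage depths; real ones come from generate_regnet()
assert sum(stage_depths) == depth

drop_path_rate = 0.1
stage_dpr = np.split(np.linspace(0., drop_path_rate, depth), np.cumsum(stage_depths[:-1]))
for i, rates in enumerate(stage_dpr):
    print(f's{i + 1}', np.round(rates, 3))
# each block then wraps its residual in DropPath(rate) when rate > 0, nn.Identity() otherwise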
name = "b{}".format(i + 1) self.add_module( name, block_fn( - block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, - downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs) + block_in_chs, out_chs, stride=block_stride, dilation=block_dilation, + drop_path_rate=dpr, **block_kwargs) ) + first_dilation = dilation def forward(self, x): for block in self.children(): @@ -231,33 +275,34 @@ class RegStage(nn.Module): class RegNet(nn.Module): - """RegNet model. + """RegNet-X, Y, and Z Models Paper: https://arxiv.org/abs/2003.13678 Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py """ - def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., - drop_path_rate=0., zero_init_last_bn=True): + def __init__( + self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', + drop_rate=0., drop_path_rate=0., zero_init_last=True): super().__init__() - # TODO add drop block, drop path, anti-aliasing, custom bn/act args self.num_classes = num_classes self.drop_rate = drop_rate assert output_stride in (8, 16, 32) # Construct the stem - stem_width = cfg['stem_width'] - self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) + stem_width = cfg.stem_width + self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] # Construct the stages prev_width = stem_width curr_stride = 2 stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) - se_ratio = cfg['se_ratio'] for i, stage_args in enumerate(stage_params): stage_name = "s{}".format(i + 1) - self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) + self.add_module(stage_name, RegStage( + in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out, + act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args)) prev_width = stage_args['out_chs'] curr_stride *= stage_args['stride'] self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] @@ -267,31 +312,18 @@ class RegNet(nn.Module): self.head = ClassifierHead( in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) - for m in self.modules(): - if isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - nn.init.ones_(m.weight) - nn.init.zeros_(m.bias) - elif isinstance(m, nn.Linear): - nn.init.normal_(m.weight, mean=0.0, std=0.01) - nn.init.zeros_(m.bias) - if zero_init_last_bn: - for m in self.modules(): - if hasattr(m, 'zero_init_last_bn'): - m.zero_init_last_bn() - - def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.): + named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) + + def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): # Generate RegNet ws per block - w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] - widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d) + widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth) # Convert to per stage format stage_widths, stage_depths = np.unique(widths, return_counts=True) # Use the same group width, bottleneck mult and stride for each stage - stage_groups = 
[cfg['group_w'] for _ in range(num_stages)] - stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] + stage_groups = [cfg.group_size for _ in range(num_stages)] + stage_bottle_ratios = [cfg.bottle_ratio for _ in range(num_stages)] stage_strides = [] stage_dilations = [] net_stride = 2 @@ -305,11 +337,11 @@ class RegNet(nn.Module): net_stride *= stride stage_strides.append(stride) stage_dilations.append(dilation) - stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1])) + stage_dpr = np.split(np.linspace(0, drop_path_rate, cfg.depth), np.cumsum(stage_depths[:-1])) # Adjust the compatibility of ws and gws stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) - param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates'] + param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] stage_params = [ dict(zip(param_names, params)) for params in zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, @@ -333,6 +365,19 @@ class RegNet(nn.Module): return x +def _init_weights(module, name='', zero_init_last=False): + if isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(module, nn.BatchNorm2d): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + nn.init.normal_(module.weight, mean=0.0, std=0.01) + nn.init.zeros_(module.bias) + elif hasattr(module, 'zero_init_last'): + module.zero_init_last() + + def _filter_fn(state_dict): """ convert patch embedding weight from manual patchify + linear proj to conv""" if 'model' in state_dict: @@ -492,3 +537,27 @@ def regnety_160(pretrained=False, **kwargs): def regnety_320(pretrained=False, **kwargs): """RegNetY-32GF""" return _create_regnet('regnety_320', pretrained, **kwargs) + + +@register_model +def regnety_040s_gn(pretrained=False, **kwargs): + """RegNetY-4.0GF w/ GroupNorm """ + return _create_regnet('regnety_040s_gn', pretrained, **kwargs) + + +@register_model +def regnetz_005(pretrained=False, **kwargs): + """RegNetZ-500MF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. + """ + return _create_regnet('regnetz_005', pretrained, **kwargs) + + +@register_model +def regnetz_040(pretrained=False, **kwargs): + """RegNetZ-4.0GF + NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py + but it's not clear it is equivalent to paper model as not detailed in the paper. 
+ """ + return _create_regnet('regnetz_040', pretrained, **kwargs) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index 31eebd80..f3119807 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -75,7 +75,6 @@ class ResNestBottleneck(nn.Module): else: avd_stride = 0 self.radix = radix - self.drop_block = drop_block self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) self.bn1 = norm_layer(group_width) @@ -85,14 +84,16 @@ class ResNestBottleneck(nn.Module): if self.radix >= 1: self.conv2 = SplitAttn( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, - dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) + dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block) self.bn2 = nn.Identity() + self.drop_block = nn.Identity() self.act2 = nn.Identity() else: self.conv2 = nn.Conv2d( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) self.bn2 = norm_layer(group_width) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act2 = act_layer(inplace=True) self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None @@ -109,8 +110,6 @@ class ResNestBottleneck(nn.Module): out = self.conv1(x) out = self.bn1(out) - if self.drop_block is not None: - out = self.drop_block(out) out = self.act1(out) if self.avd_first is not None: @@ -118,8 +117,7 @@ class ResNestBottleneck(nn.Module): out = self.conv2(out) out = self.bn2(out) - if self.drop_block is not None: - out = self.drop_block(out) + out = self.drop_block(out) out = self.act2(out) if self.avd_last is not None: @@ -127,8 +125,6 @@ class ResNestBottleneck(nn.Module): out = self.conv3(out) out = self.bn3(out) - if self.drop_block is not None: - out = self.drop_block(out) if self.downsample is not None: shortcut = self.downsample(x) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index bbcae9a3..cb71c464 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -307,8 +307,9 @@ class BasicBlock(nn.Module): inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, dilation=first_dilation, bias=False) self.bn1 = norm_layer(first_planes) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act1 = act_layer(inplace=True) - self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None + self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else nn.Identity() self.conv2 = nn.Conv2d( first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) @@ -320,7 +321,6 @@ class BasicBlock(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation - self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): @@ -331,16 +331,12 @@ class BasicBlock(nn.Module): x = self.conv1(x) x = self.bn1(x) - if self.drop_block is not None: - x = self.drop_block(x) + x = self.drop_block(x) x = self.act1(x) - if self.aa is not None: - x = self.aa(x) + x = self.aa(x) x = self.conv2(x) x = self.bn2(x) - if self.drop_block is not None: - x = self.drop_block(x) if self.se is not None: x = self.se(x) @@ -378,8 +374,9 @@ class Bottleneck(nn.Module): first_planes, width, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, 
bias=False) self.bn2 = norm_layer(width) + self.drop_block = drop_block() if drop_block is not None else nn.Identity() self.act2 = act_layer(inplace=True) - self.aa = aa_layer(channels=width, stride=stride) if use_aa else None + self.aa = aa_layer(channels=width, stride=stride) if use_aa else nn.Identity() self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) @@ -390,7 +387,6 @@ class Bottleneck(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation - self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): @@ -401,22 +397,16 @@ class Bottleneck(nn.Module): x = self.conv1(x) x = self.bn1(x) - if self.drop_block is not None: - x = self.drop_block(x) x = self.act1(x) x = self.conv2(x) x = self.bn2(x) - if self.drop_block is not None: - x = self.drop_block(x) + x = self.drop_block(x) x = self.act2(x) - if self.aa is not None: - x = self.aa(x) + x = self.aa(x) x = self.conv3(x) x = self.bn3(x) - if self.drop_block is not None: - x = self.drop_block(x) if self.se is not None: x = self.se(x) @@ -463,11 +453,11 @@ def downsample_avg( ]) -def drop_blocks(drop_block_rate=0.): +def drop_blocks(drop_prob=0.): return [ None, None, - DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, - DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] + partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None, + partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None] def make_blocks( diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py index f27ce5d8..1cb8e2f5 100644 --- a/timm/models/rexnet.py +++ b/timm/models/rexnet.py @@ -17,7 +17,7 @@ from math import ceil from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule +from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule from .registry import register_model from .efficientnet_builder import efficientnet_init_weights @@ -63,19 +63,19 @@ class LinearBottleneck(nn.Module): if exp_ratio != 1.: dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) - self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer) + self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer) else: dw_chs = in_chs self.conv_exp = None - self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) + self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) if se_ratio > 0: self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) else: self.se = None self.act_dw = create_act_layer(dw_act_layer) - self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) + self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False) self.drop_path = drop_path def feat_channels(self, exp=False): @@ -138,7 +138,7 @@ def _build_blocks( feat_chs += [features[-1].feat_channels()] pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')] - features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer)) + features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer)) return features, feature_info @@ -153,7 +153,7 @@ class ReXNetV1(nn.Module): assert 
output_stride == 32 # FIXME support dilation stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) - self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer) + self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer) block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div) features, self.feature_info = _build_blocks( diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 4dc2aa53..87520fbe 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -14,7 +14,7 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import SelectiveKernel, ConvBnAct, create_attn +from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn from .registry import register_model from .resnet import ResNet @@ -52,7 +52,7 @@ class SelectiveKernelBasic(nn.Module): super(SelectiveKernelBasic, self).__init__() sk_kwargs = sk_kwargs or {} - conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer) assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert base_width == 64, 'BasicBlock doest not support changing base width' first_planes = planes // reduce_first @@ -60,16 +60,13 @@ class SelectiveKernelBasic(nn.Module): first_dilation = first_dilation or dilation self.conv1 = SelectiveKernel( - inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) - conv_kwargs['act_layer'] = None - self.conv2 = ConvBnAct( - first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) + inplanes, first_planes, stride=stride, dilation=first_dilation, + aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs) + self.conv2 = ConvNormAct( + first_planes, outplanes, kernel_size=3, dilation=dilation, apply_act=False, **conv_kwargs) self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample - self.stride = stride - self.dilation = dilation - self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): @@ -100,24 +97,20 @@ class SelectiveKernelBottleneck(nn.Module): super(SelectiveKernelBottleneck, self).__init__() sk_kwargs = sk_kwargs or {} - conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) + conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer) width = int(math.floor(planes * (base_width / 64)) * cardinality) first_planes = width // reduce_first outplanes = planes * self.expansion first_dilation = first_dilation or dilation - self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) + self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv2 = SelectiveKernel( first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, - **conv_kwargs, **sk_kwargs) - conv_kwargs['act_layer'] = None - self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs) + aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs) + self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs) self.se = create_attn(attn_layer, outplanes) self.act = act_layer(inplace=True) self.downsample = downsample - self.stride = stride - self.dilation = dilation 
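A smoke test of the refactored SelectiveKernel used above, which now takes aa_layer / drop_layer instead of a drop_block instance (hypothetical snippet, assuming the patched tree is importable; the printed shape is what the defaults should produce):

import torch
from timm.models.layers import SelectiveKernel

sk = SelectiveKernel(64, 128, stride=2, split_input=True)  # default paths: 3x3 dilation 1 and 3x3 dilation 2
x = torch.randn(2, 64, 32, 32)
print(sk(x).shape)  # expected torch.Size([2, 128, 16, 16])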
- self.drop_block = drop_block self.drop_path = drop_path def zero_init_last_bn(self): diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index 608cd45b..c9d8c6ff 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -20,8 +20,8 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .registry import register_model from .helpers import build_model_with_cfg -from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\ - create_attn, create_norm_act, get_norm_act_layer +from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\ + create_attn, create_norm_act_layer, get_norm_act_layer # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & @@ -189,23 +189,23 @@ class OsaBlock(nn.Module): next_in_chs = in_chs if self.depthwise and next_in_chs != mid_chs: assert not residual - self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs) + self.conv_reduction = ConvNormAct(next_in_chs, mid_chs, 1, **conv_kwargs) else: self.conv_reduction = None mid_convs = [] for i in range(layer_per_block): if self.depthwise: - conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs) + conv = SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs) else: - conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs) + conv = ConvNormAct(next_in_chs, mid_chs, 3, **conv_kwargs) next_in_chs = mid_chs mid_convs.append(conv) self.conv_mid = SequentialAppendList(*mid_convs) # feature aggregation next_in_chs = in_chs + layer_per_block * mid_chs - self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs) + self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs) if attn: self.attn = create_attn(attn, out_chs) @@ -283,9 +283,9 @@ class VovNet(nn.Module): # Stem module last_stem_stride = stem_stride // 2 - conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct + conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct self.stem = nn.Sequential(*[ - ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), + ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs), conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs), ]) @@ -395,12 +395,12 @@ def eca_vovnet39b(pretrained=False, **kwargs): @register_model def ese_vovnet39b_evos(pretrained=False, **kwargs): def norm_act_fn(num_features, **nkwargs): - return create_norm_act('evonorms0', num_features, jit=False, **nkwargs) + return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs) return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) @register_model def ese_vovnet99b_iabn(pretrained=False, **kwargs): - norm_layer = get_norm_act_layer('iabn') + norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu') return _create_vovnet( 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) diff --git a/timm/models/xception_aligned.py b/timm/models/xception_aligned.py index ea7f5c05..457dc11a 100644 --- a/timm/models/xception_aligned.py +++ b/timm/models/xception_aligned.py @@ -12,7 +12,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg -from .layers import ClassifierHead, ConvBnAct, create_conv2d +from .layers import ClassifierHead, ConvNormAct, create_conv2d, 
get_norm_act_layer from .layers.helpers import to_3tuple from .registry import register_model @@ -37,12 +37,14 @@ default_cfgs = dict( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), xception71=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), + + xception41p=_cfg(url=''), ) class SeparableConv2d(nn.Module): def __init__( - self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', + self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(SeparableConv2d, self).__init__() self.kernel_size = kernel_size @@ -50,31 +52,48 @@ class SeparableConv2d(nn.Module): # depthwise convolution self.conv_dw = create_conv2d( - inplanes, inplanes, kernel_size, stride=stride, + in_chs, in_chs, kernel_size, stride=stride, padding=padding, dilation=dilation, depthwise=True) - self.bn_dw = norm_layer(inplanes) - if act_layer is not None: - self.act_dw = act_layer(inplace=True) - else: - self.act_dw = None + self.bn_dw = norm_layer(in_chs) + self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity() # pointwise convolution - self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) - self.bn_pw = norm_layer(planes) - if act_layer is not None: - self.act_pw = act_layer(inplace=True) - else: - self.act_pw = None + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) + self.bn_pw = norm_layer(out_chs) + self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity() def forward(self, x): x = self.conv_dw(x) x = self.bn_dw(x) - if self.act_dw is not None: - x = self.act_dw(x) + x = self.act_dw(x) x = self.conv_pw(x) x = self.bn_pw(x) - if self.act_pw is not None: - x = self.act_pw(x) + x = self.act_pw(x) + return x + + +class PreSeparableConv2d(nn.Module): + def __init__( + self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='', + act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, first_act=True): + super(PreSeparableConv2d, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer) + self.kernel_size = kernel_size + self.dilation = dilation + + self.norm = norm_act_layer(in_chs, inplace=True) if first_act else nn.Identity() + # depthwise convolution + self.conv_dw = create_conv2d( + in_chs, in_chs, kernel_size, stride=stride, + padding=padding, dilation=dilation, depthwise=True) + + # pointwise convolution + self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1) + + def forward(self, x): + x = self.norm(x) + x = self.conv_dw(x) + x = self.conv_pw(x) return x @@ -88,8 +107,8 @@ class XceptionModule(nn.Module): self.out_channels = out_chs[-1] self.no_skip = no_skip if not no_skip and (self.out_channels != self.in_channels or stride != 1): - self.shortcut = ConvBnAct( - in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None) + self.shortcut = ConvNormAct( + in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, apply_act=False) else: self.shortcut = None @@ -97,7 +116,7 @@ class XceptionModule(nn.Module): self.stack = nn.Sequential() for i in range(3): if start_with_relu: - self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) + self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0)) self.stack.add_module(f'conv{i + 1}', SeparableConv2d( in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, 
act_layer=separable_act_layer, norm_layer=norm_layer)) @@ -113,11 +132,42 @@ class XceptionModule(nn.Module): return x +class PreXceptionModule(nn.Module): + def __init__( + self, in_chs, out_chs, stride=1, dilation=1, pad_type='', + no_skip=False, act_layer=nn.ReLU, norm_layer=None): + super(PreXceptionModule, self).__init__() + out_chs = to_3tuple(out_chs) + self.in_channels = in_chs + self.out_channels = out_chs[-1] + self.no_skip = no_skip + if not no_skip and (self.out_channels != self.in_channels or stride != 1): + self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride) + else: + self.shortcut = nn.Identity() + + self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True) + self.stack = nn.Sequential() + for i in range(3): + self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d( + in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, + act_layer=act_layer, norm_layer=norm_layer, first_act=i > 0)) + in_chs = out_chs[i] + + def forward(self, x): + x = self.norm(x) + skip = x + x = self.stack(x) + if not self.no_skip: + x = x + self.shortcut(skip) + return x + + class XceptionAligned(nn.Module): """Modified Aligned Xception """ - def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, + def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): super(XceptionAligned, self).__init__() self.num_classes = num_classes @@ -126,31 +176,33 @@ class XceptionAligned(nn.Module): layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) self.stem = nn.Sequential(*[ - ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), - ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) + ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), + create_conv2d(32, 64, kernel_size=3, stride=1) if preact else + ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args) ]) curr_dilation = 1 curr_stride = 2 self.feature_info = [] self.blocks = nn.Sequential() + module_fn = PreXceptionModule if preact else XceptionModule for i, b in enumerate(block_cfg): b['dilation'] = curr_dilation if b['stride'] > 1: - self.feature_info += [dict( - num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] + name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3' + self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)] next_stride = curr_stride * b['stride'] if next_stride > output_stride: curr_dilation *= b['stride'] b['stride'] = 1 else: curr_stride = next_stride - self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) + self.blocks.add_module(str(i), module_fn(**b, **layer_args)) self.num_features = self.blocks[-1].out_channels self.feature_info += [dict( num_chs=self.num_features, reduction=curr_stride, module='blocks.' 
+ str(len(self.blocks) - 1))] - + self.act = act_layer(inplace=True) if preact else nn.Identity() self.head = ClassifierHead( in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) @@ -163,6 +215,7 @@ class XceptionAligned(nn.Module): def forward_features(self, x): x = self.stem(x) x = self.blocks(x) + x = self.act(x) return x def forward(self, x): @@ -236,3 +289,22 @@ def xception71(pretrained=False, **kwargs): ] model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) return _xception('xception71', pretrained=pretrained, **model_args) + + +@register_model +def xception41p(pretrained=False, **kwargs): + """ Modified Aligned Xception-41 w/ Pre-Act + """ + block_cfg = [ + # entry flow + dict(in_chs=64, out_chs=128, stride=2), + dict(in_chs=128, out_chs=256, stride=2), + dict(in_chs=256, out_chs=728, stride=2), + # middle flow + *([dict(in_chs=728, out_chs=728, stride=1)] * 8), + # exit flow + dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2), + dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1), + ] + model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d, **kwargs) + return _xception('xception41p', pretrained=pretrained, **model_args) From 683fba76862980d33e0982c9844403e6c8a01fcd Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Dec 2021 13:51:00 -0800 Subject: [PATCH 2/3] Add drop args to benchmark.py --- benchmark.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/benchmark.py b/benchmark.py index ccd9b4fa..f1604a04 100755 --- a/benchmark.py +++ b/benchmark.py @@ -199,7 +199,11 @@ class BenchmarkRunner: num_classes=kwargs.pop('num_classes', None), in_chans=3, global_pool=kwargs.pop('gp', 'fast'), - scriptable=torchscript) + scriptable=torchscript, + drop_rate=kwargs.pop('drop', 0.), + drop_path_rate=kwargs.pop('drop_path', None), + drop_block_rate=kwargs.pop('drop_block', None), + ) self.model.to( device=self.device, dtype=self.model_dtype, From a52a614475b0a869b9b30c0ed629043407da4dd2 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 14 Dec 2021 14:29:32 -0800 Subject: [PATCH 3/3] Remove layer experiment which should not have been added --- timm/models/layers/pooled_attn.py | 143 ------------------------------ 1 file changed, 143 deletions(-) delete mode 100644 timm/models/layers/pooled_attn.py diff --git a/timm/models/layers/pooled_attn.py b/timm/models/layers/pooled_attn.py deleted file mode 100644 index 40cf2b34..00000000 --- a/timm/models/layers/pooled_attn.py +++ /dev/null @@ -1,143 +0,0 @@ -from typing import List - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .helpers import to_2tuple -from .weight_init import trunc_normal_ - - -def rel_logits_1d(q, rel_k, permute_mask: List[int]): - """ Compute relative logits along one dimension - - As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 - Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 - - Args: - q: (batch, heads, height, width, dim) - rel_k: (2 * width - 1, dim) - permute_mask: permute output dim according to this - """ - B, H, W, dim = q.shape - x = (q @ rel_k.transpose(-1, -2)) - x = x.reshape(-1, W, 2 * W -1) - - # pad to shift from relative to absolute indexing - x_pad = F.pad(x, [0, 1]).flatten(1) - x_pad = F.pad(x_pad, [0, W - 1]) - - # reshape and slice out the padded elements - x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1) - x = x_pad[:, :W, 
W - 1:] - - # reshape and tile - x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1) - return x.permute(permute_mask) - - -class PosEmbedRel(nn.Module): - """ Relative Position Embedding - As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2 - Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925 - """ - def __init__(self, feat_size, dim_head, scale): - super().__init__() - self.height, self.width = to_2tuple(feat_size) - self.dim_head = dim_head - self.scale = scale - self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale) - self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale) - - def forward(self, q): - B, num_heads, HW, _ = q.shape - - # relative logits in width dimension. - q = q.reshape(B * num_heads, self.height, self.width, -1) - rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4)) - - # relative logits in height dimension. - q = q.transpose(1, 2) - rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2)) - - rel_logits = rel_logits_h + rel_logits_w - rel_logits = rel_logits.reshape(B, num_heads, HW, HW) - return rel_logits - - -class BottleneckAttn(nn.Module): - """ Bottleneck Attention - Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605 - """ - def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False): - super().__init__() - assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required' - dim_out = dim_out or dim - assert dim_out % num_heads == 0 - self.num_heads = num_heads - self.dim_out = dim_out - self.dim_head = dim_out // num_heads - self.scale = self.dim_head ** -0.5 - - self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias) - - # NOTE I'm only supporting relative pos embedding for now - self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale) - - self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity() - - self.reset_parameters() - - def reset_parameters(self): - trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5) - trunc_normal_(self.pos_embed.height_rel, std=self.scale) - trunc_normal_(self.pos_embed.width_rel, std=self.scale) - - def forward(self, x): - B, C, H, W = x.shape - assert H == self.pos_embed.height - assert W == self.pos_embed.width - - x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W - x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2) - q, k, v = torch.split(x, self.num_heads, dim=1) - - attn_logits = (q @ k.transpose(-1, -2)) * self.scale - attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W - - attn_out = attn_logits.softmax(dim=-1) - attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W - attn_out = self.pool(attn_out) - return attn_out - - -class PoolingAttention(nn.Module): - def __init__(self, in_features: int, attention_features: int, segments: int, max_pool_kernel: int): - super(PoolingAttention, self).__init__() - self.attn = nn.Linear(in_features, attention_features * 5) - self.segments = segments - self.max_pool_kernel = max_pool_kernel - - def forward(self, inp: torch.Tensor): # Shape: [Batch, Sequence, Features] - batch, sequence, features = inp.size() - assert sequence % self.segments == 0 - - qry, key, val, seg, loc = self.attn(inp).chunk(5, 2) # 5x Shape: [Batch, Sequence, AttentionFeatures] - - aggregated = 
qry.mean(1, keepdim=True) # Shape: [Batch, AttentionFeatures] - aggregated = torch.einsum("ba,bsa->bs", aggregated, key) # Shape: [Batch, Sequence] - aggregated = F.softmax(aggregated, 1) - aggregated = torch.einsum("bs,bsa,bza->bza", aggregated, val, - qry) # Shape: [Batch, Sequence, AttentionFeatures] - - pooled_sequence = sequence // self.segments - segment_max_pooled = seg.view(batch, pooled_sequence, self.segments, -1) - segment_max_pooled = segment_max_pooled.max(2, keepdim=True) # Shape: [Batch, PooledSequence, 1, AttentionFeatures] - segment_max_pooled = segment_max_pooled * qry.view(batch, pooled_sequence, self.segments, -1) # Shape: [Batch, PooledSequence, PoolSize, AttentionFeatures] - segment_max_pooled = segment_max_pooled.view(batch, sequence, -1) # Shape: [Batch, Sequence, AttentionFeatures] - - loc = loc.transpose(1, 2) # Shape: [Batch, AttentionFeatures, Sequence] - local_max_pooled = F.max_pool1d(loc, self.max_pool_kernel, 1, self.max_pool_kernel // 2) - local_max_pooled = local_max_pooled.transpose(1, 2) # Shape: [Batch, Sequence, AttentionFeatures] - - return aggregated + segment_max_pooled + local_max_pooled \ No newline at end of file
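
The sknet.py and vovnet.py hunks above show the calling convention that comes with the ConvBnAct to ConvNormAct rename: the trailing activation is now disabled with apply_act=False rather than act_layer=None, and anti-aliasing / drop layers are passed as aa_layer / drop_layer. The sketch below only mirrors keyword arguments that actually appear in this patch; it assumes ConvNormAct, ConvNormActAa and BlurPool2d are all exported from timm.models.layers at this revision, which is not verified here.

    import torch
    import torch.nn as nn
    from timm.models.layers import ConvNormAct, ConvNormActAa, BlurPool2d

    # 3x3 conv + norm + act, then a 1x1 projection with the activation disabled
    # (apply_act=False replaces the old act_layer=None convention).
    conv1 = ConvNormAct(64, 128, kernel_size=3, stride=2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d)
    conv2 = ConvNormAct(128, 256, kernel_size=1, apply_act=False)

    # Anti-aliased variant: aa_layer handles the strided downsampling.
    conv3 = ConvNormActAa(64, 128, kernel_size=3, stride=2, aa_layer=BlurPool2d)

    x = torch.randn(2, 64, 32, 32)
    print(conv2(conv1(x)).shape)   # (2, 256, 16, 16)
    print(conv3(x).shape)          # (2, 128, 16, 16)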
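
vovnet.py and xception_aligned.py above also move from create_norm_act to create_norm_act_layer and route everything else through get_norm_act_layer. A rough sketch of the two entry points, restricted to the argument patterns visible in this patch; the string names 'iabn' and 'evonorms0' come straight from the vovnet.py hunk, and 'iabn' additionally needs the inplace_abn package at runtime, so it is only resolved here, not instantiated.

    import torch.nn as nn
    from timm.models.layers import get_norm_act_layer, create_norm_act_layer

    # get_norm_act_layer resolves a combined norm + act layer type; it accepts
    # either a module class (as in PreSeparableConv2d above) or a string name.
    norm_act_cls = get_norm_act_layer(nn.BatchNorm2d, act_layer=nn.ReLU)
    norm_act = norm_act_cls(64, inplace=True)   # BatchNorm2d followed by ReLU in one module

    iabn_cls = get_norm_act_layer('iabn', act_layer='leaky_relu')   # requires inplace_abn to use

    # create_norm_act_layer builds the layer instance directly from a name.
    evo = create_norm_act_layer('evonorms0', 64, jit=False)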
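
The new xception41p entrypoint registered above is the pre-activation variant: with preact=True, normalization and activation run ahead of each separable conv (PreSeparableConv2d / PreXceptionModule) and the skip path is a plain strided conv. Its default_cfg has an empty weight URL, so only random init is expected for now. A quick smoke test, purely illustrative and assuming a timm build that contains this patch:

    import torch
    import timm

    model = timm.create_model('xception41p', pretrained=False)
    out = model(torch.randn(1, 3, 299, 299))
    print(out.shape)   # torch.Size([1, 1000]) with the default classifier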
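
Patch 2 forwards drop_rate, drop_path_rate and drop_block_rate from BenchmarkRunner into create_model, defaulting to 0.0 / None / None when the corresponding benchmark args are not given. The same pass-through can be exercised directly; the model name and rates below are arbitrary examples, not values taken from the patch.

    import timm

    # drop_rate is classifier dropout; drop_path_rate enables stochastic depth
    # for architectures that support it.
    model = timm.create_model('resnet50', drop_rate=0.1, drop_path_rate=0.05)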