From 307a935b790b5af8d551ebecda053cb1a9b16fcb Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 31 May 2021 13:18:11 -0700 Subject: [PATCH] Add non-local and BAT attention. Merge attn and self-attn factories into one. Add attention references to README. Add mlp 'mode' to ECA. --- README.md | 16 ++- timm/models/byobnet.py | 6 +- timm/models/efficientnet.py | 6 +- timm/models/layers/__init__.py | 6 +- timm/models/layers/create_attn.py | 45 +++++++- timm/models/layers/create_self_attn.py | 25 ----- timm/models/layers/eca.py | 28 +++-- timm/models/layers/non_local_attn.py | 145 +++++++++++++++++++++++++ timm/models/layers/selective_kernel.py | 17 +-- timm/models/layers/split_attn.py | 39 +++---- timm/models/layers/squeeze_excite.py | 2 +- timm/models/resnest.py | 17 ++- timm/models/sknet.py | 16 +-- 13 files changed, 276 insertions(+), 92 deletions(-) delete mode 100644 timm/models/layers/create_self_attn.py create mode 100644 timm/models/layers/non_local_attn.py diff --git a/README.md b/README.md index 06aee7ec..0b878a0a 100644 --- a/README.md +++ b/README.md @@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included. * SplitBachNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data * DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382) * DropBlock (https://arxiv.org/abs/1810.12890) -* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151) * Blur Pooling (https://arxiv.org/abs/1904.11486) * Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper? * Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets) +* An extensive selection of channel and/or spatial attention modules: + * Bottleneck Transformer - https://arxiv.org/abs/2101.11605 + * CBAM - https://arxiv.org/abs/1807.06521 + * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667 + * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151 + * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348 + * Global Context (GC) - https://arxiv.org/abs/1904.11492 + * Halo - https://arxiv.org/abs/2103.12731 + * Involution - https://arxiv.org/abs/2103.06255 + * Lambda Layer - https://arxiv.org/abs/2102.08602 + * Non-Local (NL) - https://arxiv.org/abs/1711.07971 + * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507 + * Selective Kernel (SK) - (https://arxiv.org/abs/1903.06586 + * Split (SPLAT) - https://arxiv.org/abs/2004.08955 + * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030 ## Results diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index 8ec8690a..d41245f5 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -35,7 +35,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple from .registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -935,7 +935,7 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo else: self_attn_kwargs = 
override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs) self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer - self_attn_layer = partial(get_self_attn(self_attn_layer), *self_attn_kwargs) \ + self_attn_layer = partial(get_attn(self_attn_layer), *self_attn_kwargs) \ if self_attn_layer is not None else None layer_fns = replace(layer_fns, self_attn=self_attn_layer) @@ -1010,7 +1010,7 @@ def get_layer_fns(cfg: ByoModelCfg): norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act) conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act) attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None - self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None + self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn) return layer_fn diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 09e47684..6426b540 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -1234,7 +1234,8 @@ def eca_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 w/ ECA attn """ # NOTE experimental config model = _gen_efficientnet( - 'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + 'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) return model @@ -1243,7 +1244,8 @@ def gc_efficientnet_b0(pretrained=False, **kwargs): """ EfficientNet-B0 w/ GlobalContext """ # NOTE experminetal config model = _gen_efficientnet( - 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + 'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, + pretrained=pretrained, **kwargs) return model diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 30a1b40d..77d1026e 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -12,7 +12,6 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn from .create_attn import get_attn, create_attn from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act -from .create_self_attn import get_self_attn, create_self_attn from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn from .evo_norm import EvoNormBatch2d, EvoNormSample2d @@ -24,16 +23,17 @@ from .involution import Involution from .linear import Linear from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp +from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, LayerNorm2d from .norm_act import BatchNormAct2d, GroupNormAct from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite -from .selective_kernel import SelectiveKernelConv +from .selective_kernel import SelectiveKernel from .separable_conv import SeparableConv2d, SeparableConvBnAct from .space_to_depth import SpaceToDepthModule -from .split_attn 
import SplitAttnConv2d +from .split_attn import SplitAttn from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame from .test_time_pool import TestTimePoolHead, apply_test_time_pool diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py index de866eea..3fed646b 100644 --- a/timm/models/layers/create_attn.py +++ b/timm/models/layers/create_attn.py @@ -1,14 +1,23 @@ -""" Select AttentionFactory Method +""" Attention Factory -Hacked together by / Copyright 2020 Ross Wightman +Hacked together by / Copyright 2021 Ross Wightman """ import torch +from functools import partial +from .bottleneck_attn import BottleneckAttn from .cbam import CbamModule, LightCbamModule from .eca import EcaModule, CecaModule from .gather_excite import GatherExcite from .global_context import GlobalContext +from .halo_attn import HaloAttn +from .involution import Involution +from .lambda_layer import LambdaLayer +from .non_local_attn import NonLocalAttn, BatNonLocalAttn +from .selective_kernel import SelectiveKernel +from .split_attn import SplitAttn from .squeeze_excite import SEModule, EffectiveSEModule +from .swin_attn import WindowAttention def get_attn(attn_type): @@ -18,12 +27,16 @@ def get_attn(attn_type): if attn_type is not None: if isinstance(attn_type, str): attn_type = attn_type.lower() + # Lightweight attention modules (channel and/or coarse spatial). + # Typically added to existing network architecture blocks in addition to existing convolutions. if attn_type == 'se': module_cls = SEModule elif attn_type == 'ese': module_cls = EffectiveSEModule elif attn_type == 'eca': module_cls = EcaModule + elif attn_type == 'ecam': + module_cls = partial(EcaModule, use_mlp=True) elif attn_type == 'ceca': module_cls = CecaModule elif attn_type == 'ge': @@ -34,6 +47,34 @@ def get_attn(attn_type): module_cls = CbamModule elif attn_type == 'lcbam': module_cls = LightCbamModule + + # Attention / attention-like modules w/ significant params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'sk': + module_cls = SelectiveKernel + elif attn_type == 'splat': + module_cls = SplitAttn + + # Self-attention / attention-like modules w/ significant compute and/or params + # Typically replace some of the existing workhorse convs in a network architecture. + # All of these accept a stride argument and can spatially downsample the input. + elif attn_type == 'lambda': + return LambdaLayer + elif attn_type == 'bottleneck': + return BottleneckAttn + elif attn_type == 'halo': + return HaloAttn + elif attn_type == 'swin': + return WindowAttention + elif attn_type == 'involution': + return Involution + elif attn_type == 'nl': + module_cls = NonLocalAttn + elif attn_type == 'bat': + module_cls = BatNonLocalAttn + + # Woops! 
else: assert False, "Invalid attn module (%s)" % attn_type elif isinstance(attn_type, bool): diff --git a/timm/models/layers/create_self_attn.py b/timm/models/layers/create_self_attn.py deleted file mode 100644 index 448ddb34..00000000 --- a/timm/models/layers/create_self_attn.py +++ /dev/null @@ -1,25 +0,0 @@ -from .bottleneck_attn import BottleneckAttn -from .halo_attn import HaloAttn -from .involution import Involution -from .lambda_layer import LambdaLayer -from .swin_attn import WindowAttention - - -def get_self_attn(attn_type): - if attn_type == 'bottleneck': - return BottleneckAttn - elif attn_type == 'halo': - return HaloAttn - elif attn_type == 'lambda': - return LambdaLayer - elif attn_type == 'swin': - return WindowAttention - elif attn_type == 'involution': - return Involution - else: - assert False, f"Unknown attn type ({attn_type})" - - -def create_self_attn(attn_type, dim, stride=1, **kwargs): - attn_fn = get_self_attn(attn_type) - return attn_fn(dim, stride=stride, **kwargs) diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py index 5c024108..e29be6ac 100644 --- a/timm/models/layers/eca.py +++ b/timm/models/layers/eca.py @@ -39,6 +39,7 @@ import torch.nn.functional as F from .create_act import create_act_layer +from .helpers import make_divisible class EcaModule(nn.Module): @@ -56,21 +57,36 @@ class EcaModule(nn.Module): act_layer: optional non-linearity after conv, enables conv bias, this is an experiment gate_layer: gating non-linearity to use """ - def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'): + def __init__( + self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid', + rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False): super(EcaModule, self).__init__() if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) assert kernel_size % 2 == 1 - has_act = act_layer is not None - self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act) - self.act = create_act_layer(act_layer) if has_act else nn.Identity() + padding = (kernel_size - 1) // 2 + if use_mlp: + # NOTE 'mlp' mode is a timm experiment, not in paper + assert channels is not None + if rd_channels is None: + rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor) + act_layer = act_layer or nn.ReLU + self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True) + self.act = create_act_layer(act_layer) + self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True) + else: + self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False) + self.act = None + self.conv2 = None self.gate = create_act_layer(gate_layer) def forward(self, x): y = x.mean((2, 3)).view(x.shape[0], 1, -1) # view for 1d conv y = self.conv(y) - y = self.act(y) # NOTE: usually a no-op, added for experimentation + if self.conv2 is not None: + y = self.act(y) + y = self.conv2(y) y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) @@ -115,7 +131,6 @@ class CecaModule(nn.Module): # implement manual circular padding self.padding = (kernel_size - 1) // 2 self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act) - self.act = create_act_layer(act_layer) if has_act else nn.Identity() self.gate = create_act_layer(gate_layer) def forward(self, x): @@ -123,7 +138,6 @@ class CecaModule(nn.Module): # Manually implement circular padding, 
F.pad does not seemed to be bugged y = F.pad(y, (self.padding, self.padding), mode='circular') y = self.conv(y) - y = self.act(y) # NOTE: usually a no-op, added for experimentation y = self.gate(y).view(x.shape[0], -1, 1, 1) return x * y.expand_as(x) diff --git a/timm/models/layers/non_local_attn.py b/timm/models/layers/non_local_attn.py new file mode 100644 index 00000000..d20a5f3e --- /dev/null +++ b/timm/models/layers/non_local_attn.py @@ -0,0 +1,145 @@ +""" Bilinear-Attention-Transform and Non-Local Attention + +Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms` + - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html +Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification +""" +import torch +from torch import nn +from torch.nn import functional as F + +from .conv_bn_act import ConvBnAct +from .helpers import make_divisible + + +class NonLocalAttn(nn.Module): + """Spatial NL block for image classification. + + This was adapted from https://github.com/BA-Transform/BAT-Image-Classification + Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net. + """ + + def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs): + super(NonLocalAttn, self).__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.scale = in_channels ** -0.5 if use_scale else 1.0 + self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True) + self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True) + self.norm = nn.BatchNorm2d(in_channels) + self.reset_parameters() + + def forward(self, x): + shortcut = x + + t = self.t(x) + p = self.p(x) + g = self.g(x) + + B, C, H, W = t.size() + t = t.view(B, C, -1).permute(0, 2, 1) + p = p.view(B, C, -1) + g = g.view(B, C, -1).permute(0, 2, 1) + + att = torch.bmm(t, p) * self.scale + att = F.softmax(att, dim=2) + x = torch.bmm(att, g) + + x = x.permute(0, 2, 1).reshape(B, C, H, W) + x = self.z(x) + x = self.norm(x) + shortcut + + return x + + def reset_parameters(self): + for name, m in self.named_modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + if len(list(m.parameters())) > 1: + nn.init.constant_(m.bias, 0.0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.GroupNorm): + nn.init.constant_(m.weight, 0) + nn.init.constant_(m.bias, 0) + + +class BilinearAttnTransform(nn.Module): + + def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): + super(BilinearAttnTransform, self).__init__() + + self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) + self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) + self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) + self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.block_size = block_size + self.groups = groups + self.in_channels = in_channels + + def resize_mat(self, x, t): + B, C, block_size, 
block_size1 = x.shape + assert block_size == block_size1 + if t <= 1: + return x + x = x.view(B * C, -1, 1, 1) + x = x * torch.eye(t, t, dtype=x.dtype, device=x.device) + x = x.view(B * C, block_size, block_size, t, t) + x = torch.cat(torch.split(x, 1, dim=1), dim=3) + x = torch.cat(torch.split(x, 1, dim=2), dim=4) + x = x.view(B, C, block_size * t, block_size * t) + return x + + def forward(self, x): + assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0 + B, C, H, W = x.shape + out = self.conv1(x) + rp = F.adaptive_max_pool2d(out, (self.block_size, 1)) + cp = F.adaptive_max_pool2d(out, (1, self.block_size)) + p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size) + q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size) + p = F.sigmoid(p) + q = F.sigmoid(q) + p = p / p.sum(dim=3, keepdim=True) + q = q / q.sum(dim=2, keepdim=True) + p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + p = p.view(B, C, self.block_size, self.block_size) + q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size( + 0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous() + q = q.view(B, C, self.block_size, self.block_size) + p = self.resize_mat(p, H // self.block_size) + q = self.resize_mat(q, W // self.block_size) + y = p.matmul(x) + y = y.matmul(q) + + y = self.conv2(y) + return y + + +class BatNonLocalAttn(nn.Module): + """ BAT + Adapted from: https://github.com/BA-Transform/BAT-Image-Classification + """ + + def __init__( + self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, + drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_): + super().__init__() + if rd_channels is None: + rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) + self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) + self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) + self.dropout = nn.Dropout2d(p=drop_rate) + + def forward(self, x): + xl = self.conv1(x) + y = self.ba(xl) + y = self.conv2(y) + y = self.dropout(y) + return y + x diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py index 10bfd0e0..246f72a6 100644 --- a/timm/models/layers/selective_kernel.py +++ b/timm/models/layers/selective_kernel.py @@ -8,6 +8,7 @@ import torch from torch import nn as nn from .conv_bn_act import ConvBnAct +from .helpers import make_divisible def _kernel_valid(k): @@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module): return x -class SelectiveKernelConv(nn.Module): +class SelectiveKernel(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1, - attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False, + def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, + rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True, drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): """ Selective Kernel Convolution Module @@ -66,8 +67,8 @@ class SelectiveKernelConv(nn.Module): stride (int): stride for convolutions dilation (int): dilation for module as a 
whole, impacts dilation of each branch groups (int): number of groups for each branch - attn_reduction (int, float): reduction factor for attention features - min_attn_channels (int): minimum attention feature channels + rd_ratio (int, float): reduction factor for attention features + min_rd_channels (int): minimum attention feature channels keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, can be viewed as grouping by path, output expands to module out_channels count @@ -75,7 +76,8 @@ class SelectiveKernelConv(nn.Module): act_layer (nn.Module): activation layer to use norm_layer (nn.Module): batchnorm/norm layer to use """ - super(SelectiveKernelConv, self).__init__() + super(SelectiveKernel, self).__init__() + out_channels = out_channels or in_channels kernel_size = kernel_size or [3, 5] # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation _kernel_valid(kernel_size) if not isinstance(kernel_size, list): @@ -101,7 +103,8 @@ class SelectiveKernelConv(nn.Module): ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) for k, d in zip(kernel_size, dilation)]) - attn_channels = max(int(out_channels / attn_reduction), min_attn_channels) + attn_channels = rd_channels or make_divisible( + out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.drop_block = drop_block diff --git a/timm/models/layers/split_attn.py b/timm/models/layers/split_attn.py index 5615aa0b..dde601be 100644 --- a/timm/models/layers/split_attn.py +++ b/timm/models/layers/split_attn.py @@ -10,6 +10,8 @@ import torch import torch.nn.functional as F from torch import nn +from .helpers import make_divisible + class RadixSoftmax(nn.Module): def __init__(self, radix, cardinality): @@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module): return x -class SplitAttnConv2d(nn.Module): - """Split-Attention Conv2d +class SplitAttn(nn.Module): + """Split-Attention (aka Splat) """ - def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, - dilation=1, groups=1, bias=False, radix=2, reduction_factor=4, + def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, + dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): - super(SplitAttnConv2d, self).__init__() + super(SplitAttn, self).__init__() + out_channels = out_channels or in_channels self.radix = radix self.drop_block = drop_block mid_chs = out_channels * radix - attn_chs = max(in_channels * radix // reduction_factor, 32) + if rd_channels is None: + attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) + else: + attn_chs = rd_channels * radix + padding = kernel_size // 2 if padding is None else padding self.conv = nn.Conv2d( in_channels, mid_chs, kernel_size, stride, padding, dilation, groups=groups * radix, bias=bias, **kwargs) - self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None + self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() self.act0 = act_layer(inplace=True) self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) - self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None + self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() self.act1 = 
act_layer(inplace=True) self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups) self.rsoftmax = RadixSoftmax(radix, groups) - @property - def in_channels(self): - return self.conv.in_channels - - @property - def out_channels(self): - return self.fc1.out_channels - def forward(self, x): x = self.conv(x) - if self.bn0 is not None: - x = self.bn0(x) + x = self.bn0(x) if self.drop_block is not None: x = self.drop_block(x) x = self.act0(x) @@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module): x_gap = x.sum(dim=1) else: x_gap = x - x_gap = F.adaptive_avg_pool2d(x_gap, 1) + x_gap = x_gap.mean((2, 3), keepdim=True) x_gap = self.fc1(x_gap) - if self.bn1 is not None: - x_gap = self.bn1(x_gap) + x_gap = self.bn1(x_gap) x_gap = self.act1(x_gap) x_attn = self.fc2(x_gap) diff --git a/timm/models/layers/squeeze_excite.py b/timm/models/layers/squeeze_excite.py index 3e8a05bb..e5da29ef 100644 --- a/timm/models/layers/squeeze_excite.py +++ b/timm/models/layers/squeeze_excite.py @@ -56,7 +56,7 @@ class EffectiveSEModule(nn.Module): """ 'Effective Squeeze-Excitation From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667 """ - def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'): + def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_): super(EffectiveSEModule, self).__init__() self.add_maxpool = add_maxpool self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0) diff --git a/timm/models/resnest.py b/timm/models/resnest.py index ac3b2559..31eebd80 100644 --- a/timm/models/resnest.py +++ b/timm/models/resnest.py @@ -11,7 +11,7 @@ from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import SplitAttnConv2d +from .layers import SplitAttn from .registry import register_model from .resnet import ResNet @@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module): self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None if self.radix >= 1: - self.conv2 = SplitAttnConv2d( + self.conv2 = SplitAttn( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) - self.bn2 = None # FIXME revisit, here to satisfy current torchscript fussyness - self.act2 = None + self.bn2 = nn.Identity() + self.act2 = nn.Identity() else: self.conv2 = nn.Conv2d( group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, @@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module): out = self.avd_first(out) out = self.conv2(out) - if self.bn2 is not None: - out = self.bn2(out) - if self.drop_block is not None: - out = self.drop_block(out) - out = self.act2(out) + out = self.bn2(out) + if self.drop_block is not None: + out = self.drop_block(out) + out = self.act2(out) if self.avd_last is not None: out = self.avd_last(out) diff --git a/timm/models/sknet.py b/timm/models/sknet.py index eb7ad8c3..82ca5bfe 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -14,7 +14,7 @@ from torch import nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg -from .layers import SelectiveKernelConv, ConvBnAct, create_attn +from .layers import SelectiveKernel, ConvBnAct, create_attn from .registry import register_model from .resnet import ResNet @@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module): outplanes = planes * 
self.expansion first_dilation = first_dilation or dilation - self.conv1 = SelectiveKernelConv( + self.conv1 = SelectiveKernel( inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None self.conv2 = ConvBnAct( @@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module): first_dilation = first_dilation or dilation self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) - self.conv2 = SelectiveKernelConv( + self.conv2 = SelectiveKernel( first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, **conv_kwargs, **sk_kwargs) conv_kwargs['act_layer'] = None @@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - min_attn_channels=16, - attn_reduction=8, - split_input=True) + sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs) @@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs): Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this variation splits the input channels to the selective convolutions to keep param count down. """ - sk_kwargs = dict( - min_attn_channels=16, - attn_reduction=8, - split_input=True) + sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True) model_args = dict( block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs), zero_init_last_bn=False, **kwargs)
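---

With `create_self_attn.py` removed, every attention variant resolves through the single `get_attn` factory: the lightweight channel / coarse-spatial modules ('se', 'ese', 'eca', 'ecam', 'ceca', 'ge', 'gc', 'cbam', 'lcbam') and the heavier conv-replacement modules ('sk', 'splat', 'nl', 'bat') return a class or partial that takes the channel count first, while the self-attention types ('lambda', 'bottleneck', 'halo', 'swin', 'involution') are returned as-is for the caller to instantiate with their own block-specific arguments. A minimal lookup sketch, assuming this branch is installed as `timm`:

```python
import torch
from timm.models.layers import get_attn

x = torch.randn(2, 64, 32, 32)

# Channel / coarse-spatial attention: the returned class takes `channels` first
# and the module keeps the input shape.
se = get_attn('se')(64)
eca = get_attn('eca')(64)
print(se(x).shape, eca(x).shape)      # torch.Size([2, 64, 32, 32]) twice

# Self-attention ("conv replacement") types come back as the raw class; their
# constructors take dim/stride plus block-specific args, so instantiate directly.
print(get_attn('halo').__name__, get_attn('lambda').__name__)   # HaloAttn LambdaLayer
```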
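The new `use_mlp=True` path in `EcaModule` inserts a 1x1 bottleneck (sized by `rd_ratio`/`rd_channels`/`rd_divisor`) ahead of the usual 1-d ECA convolution, and `get_attn('ecam')` maps to `partial(EcaModule, use_mlp=True)`. Note the MLP mode needs `channels` at construction time so the bottleneck width can be derived. A short sketch under those assumptions:

```python
import torch
from timm.models.layers import EcaModule, get_attn

x = torch.randn(2, 128, 28, 28)

eca = EcaModule(channels=128)                     # standard ECA: one 1-d conv over pooled channels
eca_mlp = EcaModule(channels=128, use_mlp=True)   # experimental mode: 1x1 bottleneck (rd_ratio=1/8) + 1-d conv
eca_mlp2 = get_attn('ecam')(channels=128)         # same module via the factory string

print(eca(x).shape, eca_mlp(x).shape, eca_mlp2(x).shape)   # all torch.Size([2, 128, 28, 28])
```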
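`non_local_attn.py` adds the classic Non-Local block (`NonLocalAttn`, 'nl') and the grouped bilinear-attention-transform variant (`BatNonLocalAttn`, 'bat'). Both are residual and preserve the input shape; BAT pools onto a `block_size x block_size` grid, so H and W must be divisible by `block_size` (7 by default). A small shape check, assuming this branch is installed:

```python
import torch
from timm.models.layers import NonLocalAttn, BatNonLocalAttn

x = torch.randn(2, 256, 56, 56)     # 56 x 56 is divisible by the default block_size=7

nl = NonLocalAttn(256)              # embedding channels default to 256 * rd_ratio(1/8) = 32
bat = BatNonLocalAttn(256)          # bottleneck to 64 chs, bilinear attn transform, expand back

print(nl(x).shape)                  # torch.Size([2, 256, 56, 56])
print(bat(x).shape)                 # torch.Size([2, 256, 56, 56])
```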
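`SelectiveKernelConv` is renamed `SelectiveKernel` and its attention-reduction arguments move to the common `rd_*` naming: `attn_reduction=8, min_attn_channels=16` becomes `rd_ratio=1/8, min_rd_channels=16` (with an optional `rd_channels` override and `rd_divisor` rounding), `out_channels` now defaults to `in_channels`, and `split_input` defaults to True. The sk_kwargs used by `skresnet18`/`skresnet34` map over directly; a rough sketch of the renamed call:

```python
import torch
from timm.models.layers import SelectiveKernel

# Old: SelectiveKernelConv(64, 64, attn_reduction=8, min_attn_channels=16, split_input=True)
sk = SelectiveKernel(64, 64, stride=1, rd_ratio=1/8, min_rd_channels=16, split_input=True)

x = torch.randn(2, 64, 32, 32)
print(sk(x).shape)   # torch.Size([2, 64, 32, 32]); 3x3 and dilated-3x3 branches mixed by the attention
```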
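`SplitAttnConv2d` becomes `SplitAttn` with the same `rd_*` convention: `reduction_factor=4` is replaced by `rd_ratio=0.25` (or an explicit `rd_channels`, scaled by `radix`), `kernel_size` defaults to 3 with `padding=kernel_size // 2`, `out_channels` defaults to `in_channels`, and the optional norm layers are now `nn.Identity()` so the forward path needs no None checks. Roughly how `ResNestBottleneck` builds its 3x3 conv in the `radix >= 1` branch:

```python
import torch
from torch import nn
from timm.models.layers import SplitAttn

splat = SplitAttn(
    64, 64, kernel_size=3, stride=1, padding=1, dilation=1,
    groups=1, radix=2, norm_layer=nn.BatchNorm2d)

x = torch.randn(2, 64, 32, 32)
print(splat(x).shape)   # torch.Size([2, 64, 32, 32]); radix branches are softmax-weighted and summed
```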
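In `byobnet.py` the `attn` and `self_attn` layer functions are now built the same way: the config string is resolved with `get_attn` and bound to its kwargs with `functools.partial`, and the block builders later call the partial with the block's channel dims. A sketch of that pattern (the `block_size`/`halo_size` kwargs below are illustrative values, not taken from a real config):

```python
from functools import partial
from timm.models.layers import get_attn

# Channel attention: bind kwargs once, call with channels later (as LayerFn.attn does).
attn_fn = partial(get_attn('se'), rd_ratio=0.25)
se = attn_fn(256)                  # SEModule over 256 channels with rd_ratio=0.25

# Self-attention via LayerFn.self_attn: the block supplies dim/stride when it
# instantiates the layer, so only the extra kwargs are bound here.
self_attn_fn = partial(get_attn('halo'), block_size=8, halo_size=2)
print(self_attn_fn.func.__name__)  # HaloAttn
```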
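The experimental `eca_efficientnet_b0` config now passes `se_layer='ecam'`, i.e. MLP-mode ECA dropped into the SE slot, alongside the existing `gc_efficientnet_b0` (`se_layer='gc'`). Neither config has pretrained weights; a construction-only sketch, assuming the usual `timm.create_model` entrypoint handles these experimental configs cleanly at this revision:

```python
import timm

# Experimental configs from efficientnet.py; no pretrained weights are published for these.
eca_b0 = timm.create_model('eca_efficientnet_b0', pretrained=False)
gc_b0 = timm.create_model('gc_efficientnet_b0', pretrained=False)
print(sum(p.numel() for p in eca_b0.parameters()), sum(p.numel() for p in gc_b0.parameters()))
```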