Add non-local and BAT attention. Merge attn and self-attn factories into one. Add attention references to README. Add mlp 'mode' to ECA.

more_attn
Ross Wightman 4 years ago
parent 17dc47c8e6
commit 307a935b79

README.md
@@ -295,10 +295,24 @@ Several (less common) features that I often utilize in my projects are included.
* SplitBatchNorm - allows splitting batch norm layers between clean and augmented (auxiliary batch norm) data
* DropPath aka "Stochastic Depth" (https://arxiv.org/abs/1603.09382)
* DropBlock (https://arxiv.org/abs/1810.12890)
-* Efficient Channel Attention - ECA (https://arxiv.org/abs/1910.03151)
* Blur Pooling (https://arxiv.org/abs/1904.11486)
* Space-to-Depth by [mrT23](https://github.com/mrT23/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py) (https://arxiv.org/abs/1801.04590) -- original paper?
* Adaptive Gradient Clipping (https://arxiv.org/abs/2102.06171, https://github.com/deepmind/deepmind-research/tree/master/nfnets)
+* An extensive selection of channel and/or spatial attention modules:
+    * Bottleneck Transformer - https://arxiv.org/abs/2101.11605
+    * CBAM - https://arxiv.org/abs/1807.06521
+    * Effective Squeeze-Excitation (ESE) - https://arxiv.org/abs/1911.06667
+    * Efficient Channel Attention (ECA) - https://arxiv.org/abs/1910.03151
+    * Gather-Excite (GE) - https://arxiv.org/abs/1810.12348
+    * Global Context (GC) - https://arxiv.org/abs/1904.11492
+    * Halo - https://arxiv.org/abs/2103.12731
+    * Involution - https://arxiv.org/abs/2103.06255
+    * Lambda Layer - https://arxiv.org/abs/2102.08602
+    * Non-Local (NL) - https://arxiv.org/abs/1711.07971
+    * Squeeze-and-Excitation (SE) - https://arxiv.org/abs/1709.01507
+    * Selective Kernel (SK) - https://arxiv.org/abs/1903.06586
+    * Split (SPLAT) - https://arxiv.org/abs/2004.08955
+    * Shifted Window (SWIN) - https://arxiv.org/abs/2103.14030

## Results
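
All of the modules listed above are reachable through the attention factory that this commit merges into one (see the create_attn.py diff below). A minimal usage sketch, assuming a timm checkout that includes this change; the string keys ('se', 'eca', ...) come from that factory:

    import torch
    from timm.models.layers import create_attn

    x = torch.randn(2, 64, 32, 32)        # an NCHW feature map
    se = create_attn('se', 64)            # channel attention, instantiated with the channel count
    print(se(x).shape)                    # torch.Size([2, 64, 32, 32]) - attention preserves shape
    assert create_attn(None, 64) is None  # a None/empty attn_type disables attention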

timm/models/byobnet.py
@@ -35,7 +35,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, convert_norm_act, get_attn, get_self_attn, make_divisible, to_2tuple
+    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple
from .registry import register_model

__all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block']

@@ -935,7 +935,7 @@ def update_block_kwargs(block_kwargs: Dict[str, Any], block_cfg: ByoBlockCfg, mo
    else:
        self_attn_kwargs = override_kwargs(block_cfg.self_attn_kwargs, model_cfg.self_attn_kwargs)
        self_attn_layer = block_cfg.self_attn_layer or model_cfg.self_attn_layer
-        self_attn_layer = partial(get_self_attn(self_attn_layer), **self_attn_kwargs) \
+        self_attn_layer = partial(get_attn(self_attn_layer), **self_attn_kwargs) \
            if self_attn_layer is not None else None
    layer_fns = replace(layer_fns, self_attn=self_attn_layer)

@@ -1010,7 +1010,7 @@ def get_layer_fns(cfg: ByoModelCfg):
    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
    attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
-    self_attn = partial(get_self_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
+    self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
    layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)
    return layer_fn

timm/models/efficientnet.py
@@ -1234,7 +1234,8 @@ def eca_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 w/ ECA attn """
    # NOTE experimental config
    model = _gen_efficientnet(
-        'eca_efficientnet_b0', se_layer='eca', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+        'eca_efficientnet_b0', se_layer='ecam', channel_multiplier=1.0, depth_multiplier=1.0,
+        pretrained=pretrained, **kwargs)
    return model

@@ -1243,7 +1244,8 @@ def gc_efficientnet_b0(pretrained=False, **kwargs):
    """ EfficientNet-B0 w/ GlobalContext """
    # NOTE experimental config
    model = _gen_efficientnet(
-        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
+        'gc_efficientnet_b0', se_layer='gc', channel_multiplier=1.0, depth_multiplier=1.0,
+        pretrained=pretrained, **kwargs)
    return model
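
Both registrations above are experimental configs with no pretrained weights; `se_layer='ecam'` resolves to `EcaModule(use_mlp=True)` through the merged factory. A quick build sketch (illustrative only, random init):

    import torch
    import timm

    model = timm.create_model('eca_efficientnet_b0', pretrained=False)  # experimental, nothing to download
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 1000])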

timm/models/layers/__init__.py
@@ -12,7 +12,6 @@ from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
-from .create_self_attn import get_self_attn, create_self_attn
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNormBatch2d, EvoNormSample2d

@@ -24,16 +23,17 @@ from .involution import Involution
from .linear import Linear
from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
-from .selective_kernel import SelectiveKernelConv
+from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
-from .split_attn import SplitAttnConv2d
+from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
from .test_time_pool import TestTimePoolHead, apply_test_time_pool

timm/models/layers/create_attn.py
@@ -1,14 +1,23 @@
-""" Select AttentionFactory Method
-Hacked together by / Copyright 2020 Ross Wightman
+""" Attention Factory
+Hacked together by / Copyright 2021 Ross Wightman
"""
import torch
+from functools import partial

+from .bottleneck_attn import BottleneckAttn
from .cbam import CbamModule, LightCbamModule
from .eca import EcaModule, CecaModule
from .gather_excite import GatherExcite
from .global_context import GlobalContext
+from .halo_attn import HaloAttn
+from .involution import Involution
+from .lambda_layer import LambdaLayer
+from .non_local_attn import NonLocalAttn, BatNonLocalAttn
+from .selective_kernel import SelectiveKernel
+from .split_attn import SplitAttn
from .squeeze_excite import SEModule, EffectiveSEModule
+from .swin_attn import WindowAttention


def get_attn(attn_type):

@@ -18,12 +27,16 @@ def get_attn(attn_type):
    if attn_type is not None:
        if isinstance(attn_type, str):
            attn_type = attn_type.lower()
+            # Lightweight attention modules (channel and/or coarse spatial).
+            # Typically added to existing network architecture blocks in addition to existing convolutions.
            if attn_type == 'se':
                module_cls = SEModule
            elif attn_type == 'ese':
                module_cls = EffectiveSEModule
            elif attn_type == 'eca':
                module_cls = EcaModule
+            elif attn_type == 'ecam':
+                module_cls = partial(EcaModule, use_mlp=True)
            elif attn_type == 'ceca':
                module_cls = CecaModule
            elif attn_type == 'ge':

@@ -34,6 +47,34 @@ def get_attn(attn_type):
                module_cls = CbamModule
            elif attn_type == 'lcbam':
                module_cls = LightCbamModule
+            # Attention / attention-like modules w/ significant params
+            # Typically replace some of the existing workhorse convs in a network architecture.
+            # All of these accept a stride argument and can spatially downsample the input.
+            elif attn_type == 'sk':
+                module_cls = SelectiveKernel
+            elif attn_type == 'splat':
+                module_cls = SplitAttn
+            # Self-attention / attention-like modules w/ significant compute and/or params
+            # Typically replace some of the existing workhorse convs in a network architecture.
+            # All of these accept a stride argument and can spatially downsample the input.
+            elif attn_type == 'lambda':
+                return LambdaLayer
+            elif attn_type == 'bottleneck':
+                return BottleneckAttn
+            elif attn_type == 'halo':
+                return HaloAttn
+            elif attn_type == 'swin':
+                return WindowAttention
+            elif attn_type == 'involution':
+                return Involution
+            elif attn_type == 'nl':
+                module_cls = NonLocalAttn
+            elif attn_type == 'bat':
+                module_cls = BatNonLocalAttn
+            # Woops!
            else:
                assert False, "Invalid attn module (%s)" % attn_type
        elif isinstance(attn_type, bool):
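
A short sketch of how the merged factory behaves after this change (illustrative, assuming this commit): string keys resolve to classes or `functools.partial` objects, which the caller then instantiates with a channel count (and, for the heavier modules, an optional stride).

    import torch
    from functools import partial
    from timm.models.layers import get_attn

    x = torch.randn(2, 128, 28, 28)

    Eca = get_attn('eca')        # plain class
    EcaMlp = get_attn('ecam')    # partial(EcaModule, use_mlp=True)
    assert isinstance(EcaMlp, partial)

    print(Eca(128)(x).shape)     # torch.Size([2, 128, 28, 28])
    print(EcaMlp(128)(x).shape)  # torch.Size([2, 128, 28, 28])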

timm/models/layers/create_self_attn.py (deleted)
@@ -1,25 +0,0 @@
from .bottleneck_attn import BottleneckAttn
from .halo_attn import HaloAttn
from .involution import Involution
from .lambda_layer import LambdaLayer
from .swin_attn import WindowAttention


def get_self_attn(attn_type):
    if attn_type == 'bottleneck':
        return BottleneckAttn
    elif attn_type == 'halo':
        return HaloAttn
    elif attn_type == 'lambda':
        return LambdaLayer
    elif attn_type == 'swin':
        return WindowAttention
    elif attn_type == 'involution':
        return Involution
    else:
        assert False, f"Unknown attn type ({attn_type})"


def create_self_attn(attn_type, dim, stride=1, **kwargs):
    attn_fn = get_self_attn(attn_type)
    return attn_fn(dim, stride=stride, **kwargs)

timm/models/layers/eca.py
@@ -39,6 +39,7 @@ import torch.nn.functional as F

from .create_act import create_act_layer
+from .helpers import make_divisible


class EcaModule(nn.Module):

@@ -56,21 +57,36 @@ class EcaModule(nn.Module):
        act_layer: optional non-linearity after conv, enables conv bias, this is an experiment
        gate_layer: gating non-linearity to use
    """
-    def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid'):
+    def __init__(
+            self, channels=None, kernel_size=3, gamma=2, beta=1, act_layer=None, gate_layer='sigmoid',
+            rd_ratio=1/8, rd_channels=None, rd_divisor=8, use_mlp=False):
        super(EcaModule, self).__init__()
        if channels is not None:
            t = int(abs(math.log(channels, 2) + beta) / gamma)
            kernel_size = max(t if t % 2 else t + 1, 3)
        assert kernel_size % 2 == 1
-        has_act = act_layer is not None
-        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=has_act)
-        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
+        padding = (kernel_size - 1) // 2
+        if use_mlp:
+            # NOTE 'mlp' mode is a timm experiment, not in paper
+            assert channels is not None
+            if rd_channels is None:
+                rd_channels = make_divisible(channels * rd_ratio, divisor=rd_divisor)
+            act_layer = act_layer or nn.ReLU
+            self.conv = nn.Conv1d(1, rd_channels, kernel_size=1, padding=0, bias=True)
+            self.act = create_act_layer(act_layer)
+            self.conv2 = nn.Conv1d(rd_channels, 1, kernel_size=kernel_size, padding=padding, bias=True)
+        else:
+            self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=padding, bias=False)
+            self.act = None
+            self.conv2 = None
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):
        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
        y = self.conv(y)
-        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
+        if self.conv2 is not None:
+            y = self.act(y)
+            y = self.conv2(y)
        y = self.gate(y).view(x.shape[0], -1, 1, 1)
        return x * y.expand_as(x)

@@ -115,7 +131,6 @@ class CecaModule(nn.Module):
        # implement manual circular padding
        self.padding = (kernel_size - 1) // 2
        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=has_act)
-        self.act = create_act_layer(act_layer) if has_act else nn.Identity()
        self.gate = create_act_layer(gate_layer)

    def forward(self, x):

@@ -123,7 +138,6 @@ class CecaModule(nn.Module):
        # Manually implement circular padding, F.pad does not seem to be bugged
        y = F.pad(y, (self.padding, self.padding), mode='circular')
        y = self.conv(y)
-        y = self.act(y)  # NOTE: usually a no-op, added for experimentation
        y = self.gate(y).view(x.shape[0], -1, 1, 1)
        return x * y.expand_as(x)
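
A sketch contrasting the default (paper-style) ECA gate with the experimental `use_mlp` path added above; both gate the input without changing its shape:

    import torch
    from timm.models.layers import EcaModule

    x = torch.randn(2, 256, 14, 14)
    eca = EcaModule(channels=256)                    # single 1d conv over the channel descriptor
    eca_mlp = EcaModule(channels=256, use_mlp=True)  # 1x1 expand to rd_channels + act before the kernel-sized conv

    print(sum(p.numel() for p in eca.parameters()),      # just the small 1d conv kernel
          sum(p.numel() for p in eca_mlp.parameters()))  # noticeably more parameters
    print(eca(x).shape, eca_mlp(x).shape)                # both torch.Size([2, 256, 14, 14])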

timm/models/layers/non_local_attn.py (new file)
@@ -0,0 +1,145 @@
""" Bilinear-Attention-Transform and Non-Local Attention

Paper: `Non-Local Neural Networks With Grouped Bilinear Attentional Transforms`
    - https://openaccess.thecvf.com/content_CVPR_2020/html/Chi_Non-Local_Neural_Networks_With_Grouped_Bilinear_Attentional_Transforms_CVPR_2020_paper.html
Adapted from original code: https://github.com/BA-Transform/BAT-Image-Classification
"""
import torch
from torch import nn
from torch.nn import functional as F

from .conv_bn_act import ConvBnAct
from .helpers import make_divisible


class NonLocalAttn(nn.Module):
    """Spatial NL block for image classification.

    This was adapted from https://github.com/BA-Transform/BAT-Image-Classification
    Their NonLocal impl inspired by https://github.com/facebookresearch/video-nonlocal-net.
    """

    def __init__(self, in_channels, use_scale=True, rd_ratio=1/8, rd_channels=None, rd_divisor=8, **kwargs):
        super(NonLocalAttn, self).__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.scale = in_channels ** -0.5 if use_scale else 1.0
        self.t = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.p = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.g = nn.Conv2d(in_channels, rd_channels, kernel_size=1, stride=1, bias=True)
        self.z = nn.Conv2d(rd_channels, in_channels, kernel_size=1, stride=1, bias=True)
        self.norm = nn.BatchNorm2d(in_channels)
        self.reset_parameters()

    def forward(self, x):
        shortcut = x

        t = self.t(x)
        p = self.p(x)
        g = self.g(x)

        B, C, H, W = t.size()
        t = t.view(B, C, -1).permute(0, 2, 1)
        p = p.view(B, C, -1)
        g = g.view(B, C, -1).permute(0, 2, 1)

        att = torch.bmm(t, p) * self.scale
        att = F.softmax(att, dim=2)
        x = torch.bmm(att, g)

        x = x.permute(0, 2, 1).reshape(B, C, H, W)
        x = self.z(x)
        x = self.norm(x) + shortcut

        return x

    def reset_parameters(self):
        for name, m in self.named_modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if len(list(m.parameters())) > 1:
                    nn.init.constant_(m.bias, 0.0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.GroupNorm):
                nn.init.constant_(m.weight, 0)
                nn.init.constant_(m.bias, 0)


class BilinearAttnTransform(nn.Module):

    def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
        super(BilinearAttnTransform, self).__init__()

        self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
        self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
        self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.block_size = block_size
        self.groups = groups
        self.in_channels = in_channels

    def resize_mat(self, x, t):
        B, C, block_size, block_size1 = x.shape
        assert block_size == block_size1
        if t <= 1:
            return x
        x = x.view(B * C, -1, 1, 1)
        x = x * torch.eye(t, t, dtype=x.dtype, device=x.device)
        x = x.view(B * C, block_size, block_size, t, t)
        x = torch.cat(torch.split(x, 1, dim=1), dim=3)
        x = torch.cat(torch.split(x, 1, dim=2), dim=4)
        x = x.view(B, C, block_size * t, block_size * t)
        return x

    def forward(self, x):
        assert x.shape[-1] % self.block_size == 0 and x.shape[-2] % self.block_size == 0
        B, C, H, W = x.shape
        out = self.conv1(x)
        rp = F.adaptive_max_pool2d(out, (self.block_size, 1))
        cp = F.adaptive_max_pool2d(out, (1, self.block_size))
        p = self.conv_p(rp).view(B, self.groups, self.block_size, self.block_size)
        q = self.conv_q(cp).view(B, self.groups, self.block_size, self.block_size)
        p = F.sigmoid(p)
        q = F.sigmoid(q)
        p = p / p.sum(dim=3, keepdim=True)
        q = q / q.sum(dim=2, keepdim=True)
        p = p.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        p = p.view(B, C, self.block_size, self.block_size)
        q = q.view(B, self.groups, 1, self.block_size, self.block_size).expand(x.size(
            0), self.groups, C // self.groups, self.block_size, self.block_size).contiguous()
        q = q.view(B, C, self.block_size, self.block_size)
        p = self.resize_mat(p, H // self.block_size)
        q = self.resize_mat(q, W // self.block_size)
        y = p.matmul(x)
        y = y.matmul(q)

        y = self.conv2(y)
        return y


class BatNonLocalAttn(nn.Module):
    """ BAT
    Adapted from: https://github.com/BA-Transform/BAT-Image-Classification
    """

    def __init__(
            self, in_channels, block_size=7, groups=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
            drop_rate=0.2, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, **_):
        super().__init__()
        if rd_channels is None:
            rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
        self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
        self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
        self.dropout = nn.Dropout2d(p=drop_rate)

    def forward(self, x):
        xl = self.conv1(x)
        y = self.ba(xl)
        y = self.conv2(y)
        y = self.dropout(y)
        return y + x
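
Both new blocks are residual and keep the input shape; they map to the 'nl' and 'bat' factory keys. `BatNonLocalAttn` additionally asserts that H and W are divisible by `block_size`. A minimal smoke test (illustrative):

    import torch
    from timm.models.layers import NonLocalAttn, BatNonLocalAttn

    x = torch.randn(2, 64, 28, 28)
    print(NonLocalAttn(64)(x).shape)                   # torch.Size([2, 64, 28, 28])
    print(BatNonLocalAttn(64, block_size=7)(x).shape)  # torch.Size([2, 64, 28, 28]); 28 is divisible by 7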

timm/models/layers/selective_kernel.py
@@ -8,6 +8,7 @@ import torch
from torch import nn as nn

from .conv_bn_act import ConvBnAct
+from .helpers import make_divisible


def _kernel_valid(k):

@@ -45,10 +46,10 @@ class SelectiveKernelAttn(nn.Module):
        return x


-class SelectiveKernelConv(nn.Module):
+class SelectiveKernel(nn.Module):

-    def __init__(self, in_channels, out_channels, kernel_size=None, stride=1, dilation=1, groups=1,
-                 attn_reduction=16, min_attn_channels=32, keep_3x3=True, split_input=False,
+    def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
+                 rd_ratio=1./16, rd_channels=None, min_rd_channels=16, rd_divisor=8, keep_3x3=True, split_input=True,
                 drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
        """ Selective Kernel Convolution Module

@@ -66,8 +67,8 @@ class SelectiveKernelConv(nn.Module):
            stride (int): stride for convolutions
            dilation (int): dilation for module as a whole, impacts dilation of each branch
            groups (int): number of groups for each branch
-            attn_reduction (int, float): reduction factor for attention features
-            min_attn_channels (int): minimum attention feature channels
+            rd_ratio (int, float): reduction factor for attention features
+            min_rd_channels (int): minimum attention feature channels
            keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
            split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
                can be viewed as grouping by path, output expands to module out_channels count

@@ -75,7 +76,8 @@ class SelectiveKernelConv(nn.Module):
            act_layer (nn.Module): activation layer to use
            norm_layer (nn.Module): batchnorm/norm layer to use
        """
-        super(SelectiveKernelConv, self).__init__()
+        super(SelectiveKernel, self).__init__()
+        out_channels = out_channels or in_channels
        kernel_size = kernel_size or [3, 5]  # default to one 3x3 and one 5x5 branch. 5x5 -> 3x3 + dilation
        _kernel_valid(kernel_size)
        if not isinstance(kernel_size, list):

@@ -101,7 +103,8 @@ class SelectiveKernelConv(nn.Module):
            ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
            for k, d in zip(kernel_size, dilation)])

-        attn_channels = max(int(out_channels / attn_reduction), min_attn_channels)
+        attn_channels = rd_channels or make_divisible(
+            out_channels * rd_ratio, min_value=min_rd_channels, divisor=rd_divisor)
        self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
        self.drop_block = drop_block
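
With the rename, `out_channels` defaults to `in_channels` and the attention width is derived from `rd_ratio` / `min_rd_channels` / `rd_divisor` (or pinned with `rd_channels`) instead of the old `attn_reduction` / `min_attn_channels` pair. A usage sketch (illustrative):

    import torch
    from timm.models.layers import SelectiveKernel

    x = torch.randn(2, 64, 32, 32)
    sk = SelectiveKernel(64, rd_ratio=1/8, min_rd_channels=16)  # out_channels defaults to 64
    print(sk(x).shape)                                          # torch.Size([2, 64, 32, 32])

    sk_down = SelectiveKernel(64, 128, stride=2)                # can also change width and downsample
    print(sk_down(x).shape)                                     # torch.Size([2, 128, 16, 16])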

timm/models/layers/split_attn.py
@@ -10,6 +10,8 @@ import torch
import torch.nn.functional as F
from torch import nn

+from .helpers import make_divisible


class RadixSoftmax(nn.Module):
    def __init__(self, radix, cardinality):

@@ -28,41 +30,37 @@ class RadixSoftmax(nn.Module):
        return x


-class SplitAttnConv2d(nn.Module):
-    """Split-Attention Conv2d
+class SplitAttn(nn.Module):
+    """Split-Attention (aka Splat)
    """
-    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
-                 dilation=1, groups=1, bias=False, radix=2, reduction_factor=4,
+    def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
+                 dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
                 act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
-        super(SplitAttnConv2d, self).__init__()
+        super(SplitAttn, self).__init__()
+        out_channels = out_channels or in_channels
        self.radix = radix
        self.drop_block = drop_block
        mid_chs = out_channels * radix
-        attn_chs = max(in_channels * radix // reduction_factor, 32)
+        if rd_channels is None:
+            attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
+        else:
+            attn_chs = rd_channels * radix
+        padding = kernel_size // 2 if padding is None else padding
        self.conv = nn.Conv2d(
            in_channels, mid_chs, kernel_size, stride, padding, dilation,
            groups=groups * radix, bias=bias, **kwargs)
-        self.bn0 = norm_layer(mid_chs) if norm_layer is not None else None
+        self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
        self.act0 = act_layer(inplace=True)
        self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
-        self.bn1 = norm_layer(attn_chs) if norm_layer is not None else None
+        self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
        self.act1 = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(attn_chs, mid_chs, 1, groups=groups)
        self.rsoftmax = RadixSoftmax(radix, groups)

-    @property
-    def in_channels(self):
-        return self.conv.in_channels
-
-    @property
-    def out_channels(self):
-        return self.fc1.out_channels
-
    def forward(self, x):
        x = self.conv(x)
-        if self.bn0 is not None:
-            x = self.bn0(x)
+        x = self.bn0(x)
        if self.drop_block is not None:
            x = self.drop_block(x)
        x = self.act0(x)

@@ -73,10 +71,9 @@ class SplitAttnConv2d(nn.Module):
            x_gap = x.sum(dim=1)
        else:
            x_gap = x
-        x_gap = F.adaptive_avg_pool2d(x_gap, 1)
+        x_gap = x_gap.mean((2, 3), keepdim=True)
        x_gap = self.fc1(x_gap)
-        if self.bn1 is not None:
-            x_gap = self.bn1(x_gap)
+        x_gap = self.bn1(x_gap)
        x_gap = self.act1(x_gap)
        x_attn = self.fc2(x_gap)
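
`SplitAttn` now defaults `kernel_size=3`, derives 'same' padding when `padding=None`, and defaults `out_channels` to `in_channels`, so a ResNeSt-style splat conv needs few arguments. A usage sketch (illustrative):

    import torch
    from timm.models.layers import SplitAttn

    x = torch.randn(2, 64, 32, 32)
    print(SplitAttn(64, radix=2)(x).shape)                 # torch.Size([2, 64, 32, 32])
    print(SplitAttn(64, 128, stride=2, radix=2)(x).shape)  # torch.Size([2, 128, 16, 16])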

timm/models/layers/squeeze_excite.py
@@ -56,7 +56,7 @@ class EffectiveSEModule(nn.Module):
    """ 'Effective Squeeze-Excitation
    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
    """
-    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid'):
+    def __init__(self, channels, add_maxpool=False, gate_layer='hard_sigmoid', **_):
        super(EffectiveSEModule, self).__init__()
        self.add_maxpool = add_maxpool
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)

timm/models/resnest.py
@@ -11,7 +11,7 @@ from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
-from .layers import SplitAttnConv2d
+from .layers import SplitAttn
from .registry import register_model
from .resnet import ResNet

@@ -83,11 +83,11 @@ class ResNestBottleneck(nn.Module):
        self.avd_first = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and avd_first else None

        if self.radix >= 1:
-            self.conv2 = SplitAttnConv2d(
+            self.conv2 = SplitAttn(
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
                dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
-            self.bn2 = None  # FIXME revisit, here to satisfy current torchscript fussyness
-            self.act2 = None
+            self.bn2 = nn.Identity()
+            self.act2 = nn.Identity()
        else:
            self.conv2 = nn.Conv2d(
                group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,

@@ -117,11 +117,10 @@ class ResNestBottleneck(nn.Module):
            out = self.avd_first(out)

        out = self.conv2(out)
-        if self.bn2 is not None:
-            out = self.bn2(out)
+        out = self.bn2(out)
        if self.drop_block is not None:
            out = self.drop_block(out)
        out = self.act2(out)

        if self.avd_last is not None:
            out = self.avd_last(out)
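
Replacing the optional `None` submodules with `nn.Identity()` removes the `is not None` checks from `forward()` that the old `# FIXME ... torchscript` note complained about. A tiny sketch of the pattern outside timm (illustrative; `Block` is a hypothetical module, not timm code):

    import torch
    from torch import nn

    class Block(nn.Module):
        # Optional layers become nn.Identity() instead of None, so forward() is branch-free
        # and scripts cleanly regardless of configuration.
        def __init__(self, chs, use_bn=True):
            super().__init__()
            self.conv = nn.Conv2d(chs, chs, 3, padding=1)
            self.bn = nn.BatchNorm2d(chs) if use_bn else nn.Identity()
            self.act = nn.ReLU(inplace=True)

        def forward(self, x):
            return self.act(self.bn(self.conv(x)))

    scripted = torch.jit.script(Block(16, use_bn=False))
    print(scripted(torch.randn(1, 16, 8, 8)).shape)  # torch.Size([1, 16, 8, 8])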

timm/models/sknet.py
@@ -14,7 +14,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
-from .layers import SelectiveKernelConv, ConvBnAct, create_attn
+from .layers import SelectiveKernel, ConvBnAct, create_attn
from .registry import register_model
from .resnet import ResNet

@@ -59,7 +59,7 @@ class SelectiveKernelBasic(nn.Module):
        outplanes = planes * self.expansion
        first_dilation = first_dilation or dilation

-        self.conv1 = SelectiveKernelConv(
+        self.conv1 = SelectiveKernel(
            inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
        conv_kwargs['act_layer'] = None
        self.conv2 = ConvBnAct(

@@ -107,7 +107,7 @@ class SelectiveKernelBottleneck(nn.Module):
        first_dilation = first_dilation or dilation
        self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
-        self.conv2 = SelectiveKernelConv(
+        self.conv2 = SelectiveKernel(
            first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
            **conv_kwargs, **sk_kwargs)
        conv_kwargs['act_layer'] = None

@@ -153,10 +153,7 @@ def skresnet18(pretrained=False, **kwargs):
    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.
    """
-    sk_kwargs = dict(
-        min_attn_channels=16,
-        attn_reduction=8,
-        split_input=True)
+    sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[2, 2, 2, 2], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last_bn=False, **kwargs)

@@ -170,10 +167,7 @@ def skresnet34(pretrained=False, **kwargs):
    Different from configs in Select Kernel paper or "Compounding the Performance Improvements..." this
    variation splits the input channels to the selective convolutions to keep param count down.
    """
-    sk_kwargs = dict(
-        min_attn_channels=16,
-        attn_reduction=8,
-        split_input=True)
+    sk_kwargs = dict(min_rd_channels=16, rd_ratio=1/8, split_input=True)
    model_args = dict(
        block=SelectiveKernelBasic, layers=[3, 4, 6, 3], block_args=dict(sk_kwargs=sk_kwargs),
        zero_init_last_bn=False, **kwargs)
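
The sk_kwargs changes above are essentially a rename: `attn_reduction=8` becomes `rd_ratio=1/8` (reduction expressed as a ratio, equivalent up to make_divisible rounding) and `min_attn_channels` becomes `min_rd_channels`. The registered models still build the same way (illustrative; random init):

    import torch
    import timm

    model = timm.create_model('skresnet18', pretrained=False)
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])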
