Implement Eca modules

implement ECA module by
1. adopting original eca_module.py into models folder
2. adding use_eca layer besides every instance of SE layer
pull/82/head
Chris Ha 5 years ago
parent 697e05cb3e
commit f87fcd7e88

1
.gitignore vendored

@ -105,3 +105,4 @@ venv.bak/
*.pth
*.gz
Untitled.ipynb
Testing notebook.ipynb

@ -0,0 +1,110 @@
'''
ECA module from ECAnet
original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet
original ECA model borrowed from original github
modified circular ECA implementation and
adoptation for use in pytorch image models package
by Chris Ha https://github.com/VRandme
MIT License
Copyright (c) 2019 BangguWu, Qilong Wang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import torch
from torch import nn
from torch.nn.parameter import Parameter
class eca_layer(nn.Module):
"""Constructs a ECA module.
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, k_size=3):
super(eca_layer, self).__init__()
assert k_size % 2 == 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: input features with shape [b, c, h, w]
b, c, h, w = x.size()
# feature descriptor on the global spatial information
y = self.avg_pool(x)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
y = self.sigmoid(y)
return x * y.expand_as(x)
class ceca_layer(nn.Module):
"""Constructs a circular ECA module.
the primary difference is that the conv uses a circular padding rather than zero padding.
This is because unlike images, the channels themselves do not have inherent ordering nor
locality. Although this module in essence, applies such an assumption, it is unnecessary
to limit the channels on either "edge" from being circularly adapted to each other.
This will fundamentally increase connectivity and possibly increase performance metrics
(accuracy, robustness), without signficantly impacting resource metrics
(parameter size, throughput,latency, etc)
Args:
channel: Number of channels of the input feature map
k_size: Adaptive selection of kernel size
"""
def __init__(self, channel, k_size=3):
super(ceca_layer, self).__init__()
assert k_size % 2 == 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is bugged as of pytorch 1.4
# see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding = 0, bias=False)
self.padding = (k_size - 1) // 2
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# x: input features with shape [b, c, h, w]
b, c, h, w = x.size()
# feature descriptor on the global spatial information
y = self.avg_pool(x)
#manually implement circular padding
y = torch.cat((y[:,:self.padding,:,:], y, y[:,-self.padding:,:,:]),dim=1)
# Two different branches of ECA module
y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
# Multi-scale information fusion
y = self.sigmoid(y)
return x * y.expand_as(x)

@ -14,6 +14,7 @@ import torch.nn.functional as F
from .registry import register_model
from .helpers import load_pretrained
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .eca_module import eca_layer
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@ -100,6 +101,10 @@ default_cfgs = {
'seresnext26tn_32x4d': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
interpolation='bicubic'),
'ecaresnext26tn_32x4d': _cfg(
url='',
interpolation='bicubic'),
}
@ -132,7 +137,7 @@ class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
cardinality=1, base_width=64, use_se=False, use_eca = False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(BasicBlock, self).__init__()
@ -150,7 +155,10 @@ class BasicBlock(nn.Module):
first_planes, outplanes, kernel_size=3, padding=previous_dilation,
dilation=previous_dilation, bias=False)
self.bn2 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.eca = eca_layer(outplanes) if use_eca else None
self.act2 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
@ -167,6 +175,8 @@ class BasicBlock(nn.Module):
if self.se is not None:
out = self.se(out)
if self.eca is not None:
out = self.eca(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -182,7 +192,7 @@ class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None,
cardinality=1, base_width=64, use_se=False,
cardinality=1, base_width=64, use_se=False, use_eca=False,
reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(Bottleneck, self).__init__()
@ -200,7 +210,10 @@ class Bottleneck(nn.Module):
self.act2 = act_layer(inplace=True)
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
self.se = SEModule(outplanes, planes // 4) if use_se else None
self.eca = eca_layer(outplanes) if use_eca else None
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
@ -222,6 +235,8 @@ class Bottleneck(nn.Module):
if self.se is not None:
out = self.se(out)
if self.eca is not None:
out = self.eca(out)
if self.downsample is not None:
residual = self.downsample(x)
@ -275,6 +290,8 @@ class ResNet(nn.Module):
Number of input (color) channels.
use_se : bool, default False
Enable Squeeze-Excitation module in blocks
use_eca : bool, default False
Enable ECA module in blocks
cardinality : int, default 1
Number of convolution groups for 3x3 conv in Bottleneck.
base_width : int, default 64
@ -303,7 +320,7 @@ class ResNet(nn.Module):
global_pool : str, default 'avg'
Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
"""
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
cardinality=1, base_width=64, stem_width=64, stem_type='',
block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
@ -350,7 +367,7 @@ class ResNet(nn.Module):
assert output_stride == 32
llargs = list(zip(channels, layers, strides, dilations))
lkwargs = dict(
use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
@ -375,7 +392,7 @@ class ResNet(nn.Module):
nn.init.constant_(m.bias, 0.)
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs):
norm_layer = kwargs.get('norm_layer')
downsample = None
down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
@ -396,7 +413,7 @@ class ResNet(nn.Module):
first_dilation = 1 if dilation in (1, 2) else 2
bkwargs = dict(
cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
use_se=use_se, **kwargs)
use_se=use_se, use_eca=use_eca, **kwargs)
layers = [block(
self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
self.inplanes = planes * block.expansion
@ -944,3 +961,20 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model
@register_model
def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
"""Constructs a eca-ResNeXt-26-TN model.
This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant.
this model replaces SE module with the ECA module
"""
default_cfg = default_cfgs['ecaresnext26tn_32x4d']
model = ResNet(
Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
num_classes=num_classes, in_chans=in_chans, **kwargs)
model.default_cfg = default_cfg
if pretrained:
load_pretrained(model, default_cfg, num_classes, in_chans)
return model

Loading…
Cancel
Save