A little bit of ECA cleanup

pull/87/head
Ross Wightman 5 years ago
parent 46471df7b2
commit 7011cd0902

@@ -1,14 +1,16 @@
'''
"""
ECA module from ECAnet
original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet
original ECA model borrowed from original github
modified circular ECA implementation and
adoptation for use in pytorch image models package
Original ECA model borrowed from https://github.com/BangguWu/ECANet
Modified circular ECA implementation and adaption for use in timm package
by Chris Ha https://github.com/VRandme
Original License:
MIT License
Copyright (c) 2019 BangguWu, Qilong Wang
@@ -30,13 +32,14 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
"""
import math
from torch import nn
import torch.nn.functional as F
class EcaModule(nn.Module):
"""Constructs a ECA module.
"""Constructs an ECA module.
Args:
channel: Number of channels of the input feature map for use in adaptive kernel sizes
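For reference, the adaptive kernel size mentioned in the Args above is derived from the channel count in the ECA-Net paper; a minimal sketch of that computation (gamma=2 and beta=1 are the paper's defaults, assumed here rather than taken from this hunk) could be:

import math

def adaptive_kernel_size(channels, gamma=2, beta=1):
    # ECA-Net: k = |log2(C)/gamma + beta/gamma|, constrained to an odd value
    t = int(abs(math.log(channels, 2) + beta) / gamma)
    return t if t % 2 else t + 1  # even t is bumped to the next odd value so padding stays symmetric

adaptive_kernel_size(512)  # -> 5 for a 512-channel feature map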
@@ -59,9 +62,9 @@ class EcaModule(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# feature descriptor on the global spatial information
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
# reshape for convolution
# Reshape for convolution
y = y.view(x.shape[0], 1, -1)
# Two different branches of ECA module
y = self.conv(y)
@@ -69,10 +72,12 @@ class EcaModule(nn.Module):
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
return x * y.expand_as(x)
class CecaModule(nn.Module):
"""Constructs a circular ECA module.
the primary difference is that the conv uses a circular padding rather than zero padding.
This is because unlike images, the channels themselves do not have inherent ordering nor
ECA module where the conv uses circular padding rather than zero padding.
Unlike the spatial dimension, the channels do not have inherent ordering nor
locality. Although this module, in essence, applies such an assumption, it is unnecessary
to limit the channels on either "edge" from being circularly adapted to each other.
This will fundamentally increase connectivity and possibly increase performance metrics
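As a concrete illustration of the padding argument above (this snippet is an editor's example, not part of the commit), circular padding wraps the first and last channels around to each other while zero padding does not:

import torch
import torch.nn.functional as F

y = torch.arange(8.).view(1, 1, 8)   # pooled descriptor shaped (batch, 1, channels)
pad = 1                              # (k_size - 1) // 2 for a kernel size of 3
print(F.pad(y, (pad, pad)))                   # zeros at the edges: [0., 0., 1., ..., 7., 0.]
print(F.pad(y, (pad, pad), mode='circular'))  # channels wrap around: [7., 0., 1., ..., 7., 0.]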
@@ -97,7 +102,7 @@ class CecaModule(nn.Module):
k_size = t if t % 2 else t + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is bugged as of pytorch 1.4
#pytorch circular padding mode is buggy as of pytorch 1.4
#see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding
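The workaround these comments describe boils down to creating the Conv1d with padding=0 and applying the circular padding by hand before the convolution; a hedged sketch of that pattern (variable names are illustrative, not the exact ones in the module) follows:

import torch
from torch import nn
import torch.nn.functional as F

k_size = 3
padding = (k_size - 1) // 2
conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)  # no built-in padding

y = torch.randn(2, 1, 64)                          # (batch, 1, channels)
y = F.pad(y, (padding, padding), mode='circular')  # manual circular padding
y = conv(y)                                        # length returns to 64 after the conv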
@@ -106,10 +111,10 @@ class CecaModule(nn.Module):
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# feature descriptor on the global spatial information
# Feature descriptor on the global spatial information
y = self.avg_pool(x)
#manually implement circular padding, F.pad does not seemed to be bugged
# Manually implement circular padding, F.pad does not seem to be bugged
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
# Two different branches of ECA module
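Because the diff only shows fragments of the forward passes, here is a self-contained sketch of the overall ECA gating flow (global average pool, 1D conv across channels, sigmoid gate, rescale); it is an approximation written for illustration with assumed defaults, not a copy of the module in this commit:

import torch
from torch import nn

class EcaSketch(nn.Module):
    """Illustrative ECA-style channel gate; k_size=3 is an assumed default."""
    def __init__(self, k_size=3):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)

    def forward(self, x):
        y = self.avg_pool(x)                    # (B, C, 1, 1) global spatial descriptor
        y = y.view(x.shape[0], 1, -1)           # (B, 1, C) treat channels as a 1D sequence
        y = self.conv(y)                        # local cross-channel interaction
        y = torch.sigmoid(y.view(x.shape[0], -1, 1, 1))  # per-channel gate in (0, 1)
        return x * y.expand_as(x)               # rescale the input feature map

x = torch.randn(2, 64, 32, 32)
print(EcaSketch()(x).shape)  # torch.Size([2, 64, 32, 32])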
