A little bit of ECA cleanup

pull/87/head
Ross Wightman 5 years ago
parent 46471df7b2
commit 7011cd0902

@ -1,14 +1,16 @@
''' """
ECA module from ECAnet ECA module from ECAnet
original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151 https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet Original ECA model borrowed from https://github.com/BangguWu/ECANet
original ECA model borrowed from original github
modified circular ECA implementation and Modified circular ECA implementation and adaption for use in timm package
adoptation for use in pytorch image models package
by Chris Ha https://github.com/VRandme by Chris Ha https://github.com/VRandme
Original License:
MIT License MIT License
Copyright (c) 2019 BangguWu, Qilong Wang Copyright (c) 2019 BangguWu, Qilong Wang
@ -30,13 +32,14 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
''' """
import math import math
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
class EcaModule(nn.Module): class EcaModule(nn.Module):
"""Constructs a ECA module. """Constructs an ECA module.
Args: Args:
channel: Number of channels of the input feature map for use in adaptive kernel sizes channel: Number of channels of the input feature map for use in adaptive kernel sizes
@ -59,9 +62,9 @@ class EcaModule(nn.Module):
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
# feature descriptor on the global spatial information # Feature descriptor on the global spatial information
y = self.avg_pool(x) y = self.avg_pool(x)
# reshape for convolution # Reshape for convolution
y = y.view(x.shape[0], 1, -1) y = y.view(x.shape[0], 1, -1)
# Two different branches of ECA module # Two different branches of ECA module
y = self.conv(y) y = self.conv(y)
@ -69,10 +72,12 @@ class EcaModule(nn.Module):
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
return x * y.expand_as(x) return x * y.expand_as(x)
class CecaModule(nn.Module): class CecaModule(nn.Module):
"""Constructs a circular ECA module. """Constructs a circular ECA module.
the primary difference is that the conv uses a circular padding rather than zero padding.
This is because unlike images, the channels themselves do not have inherent ordering nor ECA module where the conv uses circular padding rather than zero padding.
Unlike the spatial dimension, the channels do not have inherent ordering nor
locality. Although this module in essence, applies such an assumption, it is unnecessary locality. Although this module in essence, applies such an assumption, it is unnecessary
to limit the channels on either "edge" from being circularly adapted to each other. to limit the channels on either "edge" from being circularly adapted to each other.
This will fundamentally increase connectivity and possibly increase performance metrics This will fundamentally increase connectivity and possibly increase performance metrics
@ -97,7 +102,7 @@ class CecaModule(nn.Module):
k_size = t if t % 2 else t + 1 k_size = t if t % 2 else t + 1
self.avg_pool = nn.AdaptiveAvgPool2d(1) self.avg_pool = nn.AdaptiveAvgPool2d(1)
#pytorch circular padding mode is bugged as of pytorch 1.4 #pytorch circular padding mode is buggy as of pytorch 1.4
#see https://github.com/pytorch/pytorch/pull/17240 #see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding #implement manual circular padding
@ -106,10 +111,10 @@ class CecaModule(nn.Module):
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
# feature descriptor on the global spatial information # Feature descriptor on the global spatial information
y = self.avg_pool(x) y = self.avg_pool(x)
#manually implement circular padding, F.pad does not seemed to be bugged # Manually implement circular padding, F.pad does not seemed to be bugged
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
# Two different branches of ECA module # Two different branches of ECA module

Loading…
Cancel
Save