Update EcaModule.py

Make pylint happy
(commas, unused imports, missed imports)
pull/82/head
Chris Ha 4 years ago
parent db91ba053b
commit 904c618040

@ -1,17 +1,17 @@
''' '''
ECA module from ECAnet ECA module from ECAnet
original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151 https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet https://github.com/BangguWu/ECANet
original ECA model borrowed from original github original ECA model borrowed from original github
modified circular ECA implementation and modified circular ECA implementation and
adoptation for use in pytorch image models package adoptation for use in pytorch image models package
by Chris Ha https://github.com/VRandme by Chris Ha https://github.com/VRandme
MIT License MIT License
Copyright (c) 2019 BangguWu, Qilong Wang Copyright (c) 2019 BangguWu, Qilong Wang
Permission is hereby granted, free of charge, to any person obtaining a copy Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal of this software and associated documentation files (the "Software"), to deal
@ -31,10 +31,8 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
''' '''
import torch
from torch import nn from torch import nn
from torch.nn.parameter import Parameter import torch.nn.functional as F
class EcaModule(nn.Module): class EcaModule(nn.Module):
"""Constructs a ECA module. """Constructs a ECA module.
@ -47,7 +45,7 @@ class EcaModule(nn.Module):
super(EcaModule, self).__init__() super(EcaModule, self).__init__()
assert k_size % 2 == 1 assert k_size % 2 == 1
self.avg_pool = nn.AdaptiveAvgPool2d(1) self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
# feature descriptor on the global spatial information # feature descriptor on the global spatial information
@ -82,11 +80,11 @@ class EcaModule(nn.Module):
class CecaModule(nn.Module): class CecaModule(nn.Module):
"""Constructs a circular ECA module. """Constructs a circular ECA module.
the primary difference is that the conv uses a circular padding rather than zero padding. the primary difference is that the conv uses a circular padding rather than zero padding.
This is because unlike images, the channels themselves do not have inherent ordering nor This is because unlike images, the channels themselves do not have inherent ordering nor
locality. Although this module in essence, applies such an assumption, it is unnecessary locality. Although this module in essence, applies such an assumption, it is unnecessary
to limit the channels on either "edge" from being circularly adapted to each other. to limit the channels on either "edge" from being circularly adapted to each other.
This will fundamentally increase connectivity and possibly increase performance metrics This will fundamentally increase connectivity and possibly increase performance metrics
(accuracy, robustness), without signficantly impacting resource metrics (accuracy, robustness), without signficantly impacting resource metrics
(parameter size, throughput,latency, etc) (parameter size, throughput,latency, etc)
Args: Args:
@ -100,16 +98,16 @@ class CecaModule(nn.Module):
#pytorch circular padding mode is bugged as of pytorch 1.4 #pytorch circular padding mode is bugged as of pytorch 1.4
# see https://github.com/pytorch/pytorch/pull/17240 # see https://github.com/pytorch/pytorch/pull/17240
#implement manual circular padding #implement manual circular padding
self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding = 0, bias=False) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
self.padding = (k_size - 1) // 2 self.padding = (k_size - 1) // 2
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):
# feature descriptor on the global spatial information # feature descriptor on the global spatial information
y = self.avg_pool(x) y = self.avg_pool(x)
#manually implement circular padding, F.pad does not seemed to be bugged #manually implement circular padding, F.pad does not seemed to be bugged
y = F.pad(y.view(x.shape[0],1,-1),(self.padding,self.padding),mode='circular') y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
# Two different branches of ECA module # Two different branches of ECA module
y = self.conv(y) y = self.conv(y)

Loading…
Cancel
Save