Implement adaptive kernel selection

When the channel count is given, calculate the adaptive kernel size
according to the original paper. Otherwise use the given kernel size
(k_size), which defaults to 3.
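For reference, the paper's mapping is k = |log2(C)/gamma + beta/gamma|, nudged
to an odd number. A minimal sketch of the arithmetic this commit adds (the
helper name is illustrative, not part of the change):

    import math

    def adaptive_kernel_size(channel, gamma=2, beta=1):
        # same arithmetic as the new __init__ branches below
        t = int(abs(math.log(channel, 2) + beta) / gamma)
        return t if t % 2 else t + 1  # force an odd kernel size

    # e.g. 64 channels -> 3, 256 -> 5, 512 -> 5, 2048 -> 7
    for c in (64, 256, 512, 2048):
        print(c, adaptive_kernel_size(c))

So with the defaults (gamma=2, beta=1), EcaModule(channel=512) selects
k_size=5 automatically, while EcaModule(k_size=5) keeps the explicit behavior.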
pull/82/head
Chris Ha
parent 6db087a1ff
commit e6a762346a

@@ -31,6 +31,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
+import math
 from torch import nn
 import torch.nn.functional as F
@@ -38,15 +39,25 @@ class EcaModule(nn.Module):
     """Constructs a ECA module.

     Args:
-        channel: Number of channels of the input feature map
-        k_size: Adaptive selection of kernel size
+        channel: Number of channels of the input feature map. When given, it is
+            used to compute an adaptive kernel size via the mapping function from
+            the original paper: https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if no channel size is given, k_size is used directly)
+        gamma, beta: parameters of the mapping function, used when channel is given
+        k_size: Adaptive selection of kernel size (default=3)
     """
-    def __init__(self, channel, k_size=3):
+    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert k_size % 2 == 1
+        if channel is not None:
+            t = int(abs(math.log(channel, 2) + beta) / gamma)
+            k_size = t if t % 2 else t + 1
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         # feature descriptor on the global spatial information
         y = self.avg_pool(x)
@@ -58,25 +69,6 @@ class EcaModule(nn.Module):
         y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
         return x * y.expand_as(x)

-'''original implementation
-def forward(self, x):
-    # x: input features with shape [b, c, h, w]
-    b, c, h, w = x.size()
-    # feature descriptor on the global spatial information
-    y = self.avg_pool(x)
-    # Two different branches of ECA module
-    y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
-    # Multi-scale information fusion
-    y = self.sigmoid(y)
-    return x * y.expand_as(x)
-'''

 class CecaModule(nn.Module):
     """Constructs a circular ECA module.

     the primary difference is that the conv uses a circular padding rather than zero padding.
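For orientation before the CecaModule hunk: both modules share the same gating
pattern. A self-contained sketch of the shape flow (the reshape/conv lines sit
between the hunks shown, so the exact intermediate steps are assumed from the
surrounding file):

    import torch
    from torch import nn

    B, C, H, W = 2, 64, 32, 32
    x = torch.randn(B, C, H, W)
    avg_pool = nn.AdaptiveAvgPool2d(1)
    conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)

    y = avg_pool(x)                         # (B, C, 1, 1) one descriptor per channel
    y = y.view(B, 1, -1)                    # (B, 1, C) channels as a 1D sequence
    y = conv(y)                             # (B, 1, C) each channel mixes with k neighbors
    y = torch.sigmoid(y.view(B, -1, 1, 1))  # (B, C, 1, 1) per-channel gates in (0, 1)
    out = x * y.expand_as(x)                # broadcast gating over H and W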
@@ -88,15 +80,26 @@ class CecaModule(nn.Module):
     (parameter size, throughput, latency, etc)

     Args:
-        channel: Number of channels of the input feature map
-        k_size: Adaptive selection of kernel size
+        channel: Number of channels of the input feature map. When given, it is
+            used to compute an adaptive kernel size via the mapping function from
+            the original paper: https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if no channel size is given, k_size is used directly)
+        gamma, beta: parameters of the mapping function, used when channel is given
+        k_size: Adaptive selection of kernel size (default=3)
     """
-    def __init__(self, channel, k_size=3):
+    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert k_size % 2 == 1
+        if channel is not None:
+            t = int(abs(math.log(channel, 2) + beta) / gamma)
+            k_size = t if t % 2 else t + 1
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
         #pytorch circular padding mode is bugged as of pytorch 1.4
-        # see https://github.com/pytorch/pytorch/pull/17240
+        #see https://github.com/pytorch/pytorch/pull/17240
         #implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
         self.padding = (k_size - 1) // 2
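Since the conv here is built with padding=0, the circular padding must be
applied manually in forward, which these hunks don't show. A hedged sketch of
the idea using F.pad, whose circular mode works on 3D tensors even where
Conv1d's own circular padding_mode was buggy:

    import torch
    from torch import nn
    import torch.nn.functional as F

    B, C, k = 2, 8, 3
    padding = (k - 1) // 2
    y = torch.arange(float(C)).view(1, 1, -1).repeat(B, 1, 1)  # (B, 1, C)
    # wrap the channel sequence so the first and last channels are neighbors
    y = F.pad(y, (padding, padding), mode='circular')          # (B, 1, C + 2*padding)
    conv = nn.Conv1d(1, 1, kernel_size=k, padding=0, bias=False)
    out = conv(y)                                              # valid conv, back to (B, 1, C)

The wrap treats the channels as a closed ring, so edge channels get the same
k-neighborhood as interior ones instead of zero-padded ends.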
