From 0697ab183bae341c7da8f6617172eb7d6ca07cc2 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Thu, 30 Jan 2020 20:40:55 +0900 Subject: [PATCH 1/6] Update .gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 85ddfe0e..66172bca 100644 --- a/.gitignore +++ b/.gitignore @@ -104,3 +104,4 @@ venv.bak/ *.tar *.pth *.gz +Untitled.ipynb From f87fcd7e88fa2e8caf85c2f2b1deef9787c416a9 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Tue, 4 Feb 2020 23:15:29 +0900 Subject: [PATCH 2/6] Implement Eca modules implement ECA module by 1. adopting original eca_module.py into models folder 2. adding use_eca layer besides every instance of SE layer --- .gitignore | 1 + timm/models/eca_module.py | 110 ++++++++++++++++++++++++++++++++++++++ timm/models/resnet.py | 46 +++++++++++++--- 3 files changed, 151 insertions(+), 6 deletions(-) create mode 100644 timm/models/eca_module.py diff --git a/.gitignore b/.gitignore index 66172bca..54db5359 100644 --- a/.gitignore +++ b/.gitignore @@ -105,3 +105,4 @@ venv.bak/ *.pth *.gz Untitled.ipynb +Testing notebook.ipynb diff --git a/timm/models/eca_module.py b/timm/models/eca_module.py new file mode 100644 index 00000000..7664354e --- /dev/null +++ b/timm/models/eca_module.py @@ -0,0 +1,110 @@ +''' +ECA module from ECAnet +original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks +https://arxiv.org/abs/1910.03151 + +https://github.com/BangguWu/ECANet +original ECA model borrowed from original github +modified circular ECA implementation and +adoptation for use in pytorch image models package +by Chris Ha https://github.com/VRandme + +MIT License + +Copyright (c) 2019 BangguWu, Qilong Wang + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +import torch +from torch import nn +from torch.nn.parameter import Parameter + + +class eca_layer(nn.Module): + """Constructs a ECA module. + + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + def __init__(self, channel, k_size=3): + super(eca_layer, self).__init__() + assert k_size % 2 == 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + b, c, h, w = x.size() + + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) + +class ceca_layer(nn.Module): + """Constructs a circular ECA module. + the primary difference is that the conv uses a circular padding rather than zero padding. + This is because unlike images, the channels themselves do not have inherent ordering nor + locality. Although this module in essence, applies such an assumption, it is unnecessary + to limit the channels on either "edge" from being circularly adapted to each other. + This will fundamentally increase connectivity and possibly increase performance metrics + (accuracy, robustness), without signficantly impacting resource metrics + (parameter size, throughput,latency, etc) + + + Args: + channel: Number of channels of the input feature map + k_size: Adaptive selection of kernel size + """ + def __init__(self, channel, k_size=3): + super(ceca_layer, self).__init__() + assert k_size % 2 == 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) + #pytorch circular padding mode is bugged as of pytorch 1.4 + # see https://github.com/pytorch/pytorch/pull/17240 + #implement manual circular padding + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding = 0, bias=False) + self.padding = (k_size - 1) // 2 + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + # x: input features with shape [b, c, h, w] + b, c, h, w = x.size() + # feature descriptor on the global spatial information + y = self.avg_pool(x) + + #manually implement circular padding + y = torch.cat((y[:,:self.padding,:,:], y, y[:,-self.padding:,:,:]),dim=1) + + + # Two different branches of ECA module + y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + + # Multi-scale information fusion + y = self.sigmoid(y) + + return x * y.expand_as(x) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 422eb0cb..49395c83 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,6 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .eca_module import eca_layer from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -100,6 +101,10 @@ default_cfgs = { 'seresnext26tn_32x4d': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth', interpolation='bicubic'), + 'ecaresnext26tn_32x4d': _cfg( + url='', + interpolation='bicubic'), + } @@ -132,7 +137,7 @@ class BasicBlock(nn.Module): expansion = 1 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, + cardinality=1, base_width=64, use_se=False, use_eca = False, reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(BasicBlock, self).__init__() @@ -150,7 +155,10 @@ class BasicBlock(nn.Module): first_planes, outplanes, kernel_size=3, padding=previous_dilation, dilation=previous_dilation, bias=False) self.bn2 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.eca = eca_layer(outplanes) if use_eca else None + self.act2 = act_layer(inplace=True) self.downsample = downsample self.stride = stride @@ -167,6 +175,8 @@ class BasicBlock(nn.Module): if self.se is not None: out = self.se(out) + if self.eca is not None: + out = self.eca(out) if self.downsample is not None: residual = self.downsample(x) @@ -182,7 +192,7 @@ class Bottleneck(nn.Module): expansion = 4 def __init__(self, inplanes, planes, stride=1, downsample=None, - cardinality=1, base_width=64, use_se=False, + cardinality=1, base_width=64, use_se=False, use_eca=False, reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): super(Bottleneck, self).__init__() @@ -200,7 +210,10 @@ class Bottleneck(nn.Module): self.act2 = act_layer(inplace=True) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.bn3 = norm_layer(outplanes) + self.se = SEModule(outplanes, planes // 4) if use_se else None + self.eca = eca_layer(outplanes) if use_eca else None + self.act3 = act_layer(inplace=True) self.downsample = downsample self.stride = stride @@ -222,6 +235,8 @@ class Bottleneck(nn.Module): if self.se is not None: out = self.se(out) + if self.eca is not None: + out = self.eca(out) if self.downsample is not None: residual = self.downsample(x) @@ -275,6 +290,8 @@ class ResNet(nn.Module): Number of input (color) channels. use_se : bool, default False Enable Squeeze-Excitation module in blocks + use_eca : bool, default False + Enable ECA module in blocks cardinality : int, default 1 Number of convolution groups for 3x3 conv in Bottleneck. base_width : int, default 64 @@ -303,7 +320,7 @@ class ResNet(nn.Module): global_pool : str, default 'avg' Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax' """ - def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, + def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False, cardinality=1, base_width=64, stem_width=64, stem_type='', block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg', @@ -350,7 +367,7 @@ class ResNet(nn.Module): assert output_stride == 32 llargs = list(zip(channels, layers, strides, dilations)) lkwargs = dict( - use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, + use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer, avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args) self.layer1 = self._make_layer(block, *llargs[0], **lkwargs) self.layer2 = self._make_layer(block, *llargs[1], **lkwargs) @@ -375,7 +392,7 @@ class ResNet(nn.Module): nn.init.constant_(m.bias, 0.) def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, - use_se=False, avg_down=False, down_kernel_size=1, **kwargs): + use_se=False, use_eca=False,avg_down=False, down_kernel_size=1, **kwargs): norm_layer = kwargs.get('norm_layer') downsample = None down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size @@ -396,7 +413,7 @@ class ResNet(nn.Module): first_dilation = 1 if dilation in (1, 2) else 2 bkwargs = dict( cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first, - use_se=use_se, **kwargs) + use_se=use_se, use_eca=use_eca, **kwargs) layers = [block( self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)] self.inplanes = planes * block.expansion @@ -944,3 +961,20 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs if pretrained: load_pretrained(model, default_cfg, num_classes, in_chans) return model + +@register_model +def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs): + """Constructs a eca-ResNeXt-26-TN model. + This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels + in the deep stem. The channel number of the middle stem conv is narrower than the 'T' variant. + this model replaces SE module with the ECA module + """ + default_cfg = default_cfgs['ecaresnext26tn_32x4d'] + model = ResNet( + Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4, + stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True, + num_classes=num_classes, in_chans=in_chans, **kwargs) + model.default_cfg = default_cfg + if pretrained: + load_pretrained(model, default_cfg, num_classes, in_chans) + return model From d63ae121d52463d86116021e9ee9160a4628e46a Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Thu, 6 Feb 2020 22:44:33 +0900 Subject: [PATCH 3/6] Clean up eca_module code functionally similar adjusted rwightman's version of reshaping and viewing. Use F.pad for circular eca version for cleaner code --- timm/models/eca_module.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/timm/models/eca_module.py b/timm/models/eca_module.py index 7664354e..5cb52d96 100644 --- a/timm/models/eca_module.py +++ b/timm/models/eca_module.py @@ -49,7 +49,19 @@ class eca_layer(nn.Module): self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid() + def forward(self, x): + # feature descriptor on the global spatial information + y = self.avg_pool(x) + # reshape for convolution + y = y.view(x.shape[0], 1, -1) + # Two different branches of ECA module + y = self.conv(y) + # Multi-scale information fusion + y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) + return x * y.expand_as(x) + + '''original implementation def forward(self, x): # x: input features with shape [b, c, h, w] b, c, h, w = x.size() @@ -62,8 +74,10 @@ class eca_layer(nn.Module): # Multi-scale information fusion y = self.sigmoid(y) - return x * y.expand_as(x) + ''' + + class ceca_layer(nn.Module): """Constructs a circular ECA module. @@ -75,7 +89,6 @@ class ceca_layer(nn.Module): (accuracy, robustness), without signficantly impacting resource metrics (parameter size, throughput,latency, etc) - Args: channel: Number of channels of the input feature map k_size: Adaptive selection of kernel size @@ -92,19 +105,16 @@ class ceca_layer(nn.Module): self.sigmoid = nn.Sigmoid() def forward(self, x): - # x: input features with shape [b, c, h, w] - b, c, h, w = x.size() # feature descriptor on the global spatial information y = self.avg_pool(x) - #manually implement circular padding - y = torch.cat((y[:,:self.padding,:,:], y, y[:,-self.padding:,:,:]),dim=1) - + #manually implement circular padding, F.pad does not seemed to be bugged + y = F.pad(y.view(x.shape[0],1,-1),(self.padding,self.padding),mode='circular') # Two different branches of ECA module - y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) + y = self.conv(y) # Multi-scale information fusion - y = self.sigmoid(y) + y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) return x * y.expand_as(x) From db91ba053bbdc892eb020203190a3fe96c11860f Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Fri, 7 Feb 2020 19:28:07 +0900 Subject: [PATCH 4/6] EcaModule(CamelCase) CamelCased EcaModule. Renamed all instances of ecalayer to EcaModule. eca_module.py->EcaModule.py --- timm/models/{eca_module.py => EcaModule.py} | 8 ++++---- timm/models/resnet.py | 6 +++--- 2 files changed, 7 insertions(+), 7 deletions(-) rename timm/models/{eca_module.py => EcaModule.py} (96%) diff --git a/timm/models/eca_module.py b/timm/models/EcaModule.py similarity index 96% rename from timm/models/eca_module.py rename to timm/models/EcaModule.py index 5cb52d96..74da3170 100644 --- a/timm/models/eca_module.py +++ b/timm/models/EcaModule.py @@ -36,7 +36,7 @@ from torch import nn from torch.nn.parameter import Parameter -class eca_layer(nn.Module): +class EcaModule(nn.Module): """Constructs a ECA module. Args: @@ -44,7 +44,7 @@ class eca_layer(nn.Module): k_size: Adaptive selection of kernel size """ def __init__(self, channel, k_size=3): - super(eca_layer, self).__init__() + super(EcaModule, self).__init__() assert k_size % 2 == 1 self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) @@ -79,7 +79,7 @@ class eca_layer(nn.Module): -class ceca_layer(nn.Module): +class CecaModule(nn.Module): """Constructs a circular ECA module. the primary difference is that the conv uses a circular padding rather than zero padding. This is because unlike images, the channels themselves do not have inherent ordering nor @@ -94,7 +94,7 @@ class ceca_layer(nn.Module): k_size: Adaptive selection of kernel size """ def __init__(self, channel, k_size=3): - super(ceca_layer, self).__init__() + super(CecaModule, self).__init__() assert k_size % 2 == 1 self.avg_pool = nn.AdaptiveAvgPool2d(1) #pytorch circular padding mode is bugged as of pytorch 1.4 diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 49395c83..da755373 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -14,7 +14,7 @@ import torch.nn.functional as F from .registry import register_model from .helpers import load_pretrained from .adaptive_avgmax_pool import SelectAdaptivePool2d -from .eca_module import eca_layer +from .EcaModule import EcaModule from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD @@ -157,7 +157,7 @@ class BasicBlock(nn.Module): self.bn2 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = eca_layer(outplanes) if use_eca else None + self.eca = EcaModule(outplanes) if use_eca else None self.act2 = act_layer(inplace=True) self.downsample = downsample @@ -212,7 +212,7 @@ class Bottleneck(nn.Module): self.bn3 = norm_layer(outplanes) self.se = SEModule(outplanes, planes // 4) if use_se else None - self.eca = eca_layer(outplanes) if use_eca else None + self.eca = Eca_Module(outplanes) if use_eca else None self.act3 = act_layer(inplace=True) self.downsample = downsample From 904c618040e98b0e5fab3cdab007448488de9a12 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Fri, 7 Feb 2020 19:36:18 +0900 Subject: [PATCH 5/6] Update EcaModule.py Make pylint happy (commas, unused imports, missed imports) --- timm/models/EcaModule.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/timm/models/EcaModule.py b/timm/models/EcaModule.py index 74da3170..2eaeeb8b 100644 --- a/timm/models/EcaModule.py +++ b/timm/models/EcaModule.py @@ -1,17 +1,17 @@ ''' -ECA module from ECAnet +ECA module from ECAnet original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks https://arxiv.org/abs/1910.03151 https://github.com/BangguWu/ECANet original ECA model borrowed from original github -modified circular ECA implementation and +modified circular ECA implementation and adoptation for use in pytorch image models package by Chris Ha https://github.com/VRandme MIT License -Copyright (c) 2019 BangguWu, Qilong Wang +Copyright (c) 2019 BangguWu, Qilong Wang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -31,10 +31,8 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ''' -import torch from torch import nn -from torch.nn.parameter import Parameter - +import torch.nn.functional as F class EcaModule(nn.Module): """Constructs a ECA module. @@ -47,7 +45,7 @@ class EcaModule(nn.Module): super(EcaModule, self).__init__() assert k_size % 2 == 1 self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): # feature descriptor on the global spatial information @@ -82,11 +80,11 @@ class EcaModule(nn.Module): class CecaModule(nn.Module): """Constructs a circular ECA module. the primary difference is that the conv uses a circular padding rather than zero padding. - This is because unlike images, the channels themselves do not have inherent ordering nor + This is because unlike images, the channels themselves do not have inherent ordering nor locality. Although this module in essence, applies such an assumption, it is unnecessary to limit the channels on either "edge" from being circularly adapted to each other. This will fundamentally increase connectivity and possibly increase performance metrics - (accuracy, robustness), without signficantly impacting resource metrics + (accuracy, robustness), without signficantly impacting resource metrics (parameter size, throughput,latency, etc) Args: @@ -100,16 +98,16 @@ class CecaModule(nn.Module): #pytorch circular padding mode is bugged as of pytorch 1.4 # see https://github.com/pytorch/pytorch/pull/17240 #implement manual circular padding - self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding = 0, bias=False) + self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False) self.padding = (k_size - 1) // 2 self.sigmoid = nn.Sigmoid() def forward(self, x): # feature descriptor on the global spatial information y = self.avg_pool(x) - + #manually implement circular padding, F.pad does not seemed to be bugged - y = F.pad(y.view(x.shape[0],1,-1),(self.padding,self.padding),mode='circular') + y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') # Two different branches of ECA module y = self.conv(y) From e6a762346a14f48f581da4552a11f5694ebef255 Mon Sep 17 00:00:00 2001 From: Chris Ha <15088501+VRandme@users.noreply.github.com> Date: Sun, 9 Feb 2020 11:58:03 +0900 Subject: [PATCH 6/6] Implement Adaptive Kernel selection When channel size is given, calculate adaptive kernel size according to original paper. Otherwise use the given kernel size(k_size), which defaults to 3 --- timm/models/EcaModule.py | 55 +++++++++++++++++++++------------------- 1 file changed, 29 insertions(+), 26 deletions(-) diff --git a/timm/models/EcaModule.py b/timm/models/EcaModule.py index 2eaeeb8b..b91b5801 100644 --- a/timm/models/EcaModule.py +++ b/timm/models/EcaModule.py @@ -31,6 +31,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ''' +import math from torch import nn import torch.nn.functional as F @@ -38,15 +39,25 @@ class EcaModule(nn.Module): """Constructs a ECA module. Args: - channel: Number of channels of the input feature map - k_size: Adaptive selection of kernel size + channel: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + k_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel, k_size=3): + def __init__(self, channel=None, k_size=3, gamma=2, beta=1): super(EcaModule, self).__init__() assert k_size % 2 == 1 + + if channel is not None: + t = int(abs(math.log(channel, 2)+beta) / gamma) + k_size = t if t % 2 else t + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid() + def forward(self, x): # feature descriptor on the global spatial information y = self.avg_pool(x) @@ -58,25 +69,6 @@ class EcaModule(nn.Module): y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) return x * y.expand_as(x) - - '''original implementation - def forward(self, x): - # x: input features with shape [b, c, h, w] - b, c, h, w = x.size() - - # feature descriptor on the global spatial information - y = self.avg_pool(x) - - # Two different branches of ECA module - y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) - - # Multi-scale information fusion - y = self.sigmoid(y) - return x * y.expand_as(x) - ''' - - - class CecaModule(nn.Module): """Constructs a circular ECA module. the primary difference is that the conv uses a circular padding rather than zero padding. @@ -88,15 +80,26 @@ class CecaModule(nn.Module): (parameter size, throughput,latency, etc) Args: - channel: Number of channels of the input feature map - k_size: Adaptive selection of kernel size + channel: Number of channels of the input feature map for use in adaptive kernel sizes + for actual calculations according to channel. + gamma, beta: when channel is given parameters of mapping function + refer to original paper https://arxiv.org/pdf/1910.03151.pdf + (default=None. if channel size not given, use k_size given for kernel size.) + k_size: Adaptive selection of kernel size (default=3) """ - def __init__(self, channel, k_size=3): + + def __init__(self, channel=None, k_size=3, gamma=2, beta=1): super(CecaModule, self).__init__() assert k_size % 2 == 1 + + if channel is not None: + t = int(abs(math.log(channel, 2)+beta) / gamma) + k_size = t if t % 2 else t + 1 + self.avg_pool = nn.AdaptiveAvgPool2d(1) #pytorch circular padding mode is bugged as of pytorch 1.4 - # see https://github.com/pytorch/pytorch/pull/17240 + #see https://github.com/pytorch/pytorch/pull/17240 + #implement manual circular padding self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False) self.padding = (k_size - 1) // 2