diff --git a/.gitignore b/.gitignore
index 85ddfe0e..54db5359 100644
--- a/.gitignore
+++ b/.gitignore
@@ -104,3 +104,5 @@ venv.bak/
 *.tar
 *.pth
 *.gz
+Untitled.ipynb
+Testing notebook.ipynb
diff --git a/timm/models/EcaModule.py b/timm/models/EcaModule.py
new file mode 100644
index 00000000..b91b5801
--- /dev/null
+++ b/timm/models/EcaModule.py
@@ -0,0 +1,121 @@
+'''
+ECA module from ECAnet
+original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
+https://arxiv.org/abs/1910.03151
+
+https://github.com/BangguWu/ECANet
+Original ECA model borrowed from the original github repo above.
+Modified circular ECA implementation and adaptation for use in the
+pytorch image models package by Chris Ha https://github.com/VRandme
+
+MIT License
+
+Copyright (c) 2019 BangguWu, Qilong Wang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+'''
+import math
+from torch import nn
+import torch.nn.functional as F
+
+class EcaModule(nn.Module):
+    """Constructs an ECA module.
+
+    Args:
+        channel: Number of channels of the input feature map. When given, the kernel
+            size is selected adaptively from the channel count according to the mapping
+            function in the original paper https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if the channel count is not given, the given k_size is used).
+        k_size: Kernel size for the channel conv when channel is not given (default=3).
+        gamma, beta: Parameters of the adaptive kernel size mapping function.
+    """
+    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
+        super(EcaModule, self).__init__()
+        assert k_size % 2 == 1
+
+        if channel is not None:
+            t = int(abs(math.log(channel, 2) + beta) / gamma)
+            k_size = t if t % 2 else t + 1
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # feature descriptor on the global spatial information
+        y = self.avg_pool(x)
+        # reshape the (N, C, 1, 1) descriptor to (N, 1, C) for the 1D convolution
+        y = y.view(x.shape[0], 1, -1)
+        # 1D conv across neighbouring channels models local cross-channel interaction
+        y = self.conv(y)
+        # gate each channel of the input with its sigmoid attention weight
+        y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
+        return x * y.expand_as(x)
+
+class CecaModule(nn.Module):
+    """Constructs a circular ECA module.
+
+    The primary difference is that the conv uses circular rather than zero padding.
+    Unlike images, the channels themselves have no inherent ordering or locality,
+    and although this module in essence still applies such an assumption, there is
+    no reason to prevent the channels at either "edge" from adapting to each other
+    circularly. This increases connectivity and may improve performance metrics
+    (accuracy, robustness) without significantly impacting resource metrics
+    (parameter size, throughput, latency, etc.).
+
+    Args:
+        channel: Number of channels of the input feature map. When given, the kernel
+            size is selected adaptively from the channel count according to the mapping
+            function in the original paper https://arxiv.org/pdf/1910.03151.pdf
+            (default=None; if the channel count is not given, the given k_size is used).
+        k_size: Kernel size for the channel conv when channel is not given (default=3).
+        gamma, beta: Parameters of the adaptive kernel size mapping function.
+    """
+
+    def __init__(self, channel=None, k_size=3, gamma=2, beta=1):
+        super(CecaModule, self).__init__()
+        assert k_size % 2 == 1
+
+        if channel is not None:
+            t = int(abs(math.log(channel, 2) + beta) / gamma)
+            k_size = t if t % 2 else t + 1
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        # pytorch circular padding mode is bugged as of pytorch 1.4,
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # so circular padding is implemented manually with F.pad in forward()
+        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
+        self.padding = (k_size - 1) // 2
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        # feature descriptor on the global spatial information
+        y = self.avg_pool(x)
+
+        # manually implement circular padding; F.pad does not appear to be bugged
+        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
+
+        # 1D conv across neighbouring channels models local cross-channel interaction
+        y = self.conv(y)
+
+        # gate each channel of the input with its sigmoid attention weight
+        y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
+
+        return x * y.expand_as(x)
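For reviewers, a minimal sketch of how the new module behaves in isolation (hypothetical usage, assuming this branch is installed so the file above is importable as `timm.models.EcaModule`; the adaptive kernel-size arithmetic follows the code in the diff):

```python
import torch
from timm.models.EcaModule import EcaModule, CecaModule

# a batch of 8 feature maps with 64 channels, as might come out of a residual block
x = torch.randn(8, 64, 32, 32)

# fixed kernel size: a 3-tap 1D conv slides over the 64 pooled channel descriptors
eca = EcaModule(k_size=3)
out = eca(x)
assert out.shape == x.shape  # ECA only re-weights channels; the shape is unchanged

# adaptive kernel size: for channel=64, t = int(abs(log2(64) + 1) / 2) = 3 (odd), so k_size = 3
eca_adaptive = EcaModule(channel=64)

# circular variant: the channel descriptor is circularly padded before the conv,
# so the first and last channels are treated as neighbours
ceca = CecaModule(channel=64)
assert ceca(x).shape == x.shape
```

Note that either variant adds only `k_size` weights per module (3 here, since the 1D conv has no bias), versus the two FC layers of an SE block.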
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 422eb0cb..da755373 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from .registry import register_model
 from .helpers import load_pretrained
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .EcaModule import EcaModule
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -100,6 +101,10 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
+    'ecaresnext26tn_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+
 }
@@ -132,7 +137,7 @@ class BasicBlock(nn.Module):
     expansion = 1
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
+                 cardinality=1, base_width=64, use_se=False, use_eca=False,
                  reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU,
                  norm_layer=nn.BatchNorm2d):
         super(BasicBlock, self).__init__()
@@ -150,7 +155,10 @@
             first_planes, outplanes, kernel_size=3,
             padding=previous_dilation, dilation=previous_dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
+
         self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.eca = EcaModule(outplanes) if use_eca else None
+
         self.act2 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -167,6 +175,8 @@
         if self.se is not None:
             out = self.se(out)
+        if self.eca is not None:
+            out = self.eca(out)
 
         if self.downsample is not None:
             residual = self.downsample(x)
@@ -182,7 +192,7 @@ class Bottleneck(nn.Module):
     expansion = 4
 
     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
+                 cardinality=1, base_width=64, use_se=False, use_eca=False,
                  reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU,
                  norm_layer=nn.BatchNorm2d):
         super(Bottleneck, self).__init__()
@@ -200,7 +210,10 @@
         self.act2 = act_layer(inplace=True)
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
+
         self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.eca = EcaModule(outplanes) if use_eca else None
+
         self.act3 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -222,6 +235,8 @@
         if self.se is not None:
             out = self.se(out)
+        if self.eca is not None:
+            out = self.eca(out)
 
         if self.downsample is not None:
             residual = self.downsample(x)
@@ -275,6 +290,8 @@ class ResNet(nn.Module):
         Number of input (color) channels.
     use_se : bool, default False
         Enable Squeeze-Excitation module in blocks
+    use_eca : bool, default False
+        Enable ECA module in blocks
     cardinality : int, default 1
         Number of convolution groups for 3x3 conv in Bottleneck.
     base_width : int, default 64
@@ -303,7 +320,7 @@
     global_pool : str, default 'avg'
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """
-    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
+    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
@@ -350,7 +367,7 @@
             assert output_stride == 32
         llargs = list(zip(channels, layers, strides, dilations))
         lkwargs = dict(
-            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
+            use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
             avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
         self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
         self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
@@ -375,7 +392,7 @@
                 nn.init.constant_(m.bias, 0.)
 
     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
+                    use_se=False, use_eca=False, avg_down=False, down_kernel_size=1, **kwargs):
         norm_layer = kwargs.get('norm_layer')
         downsample = None
         down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
@@ -396,7 +413,7 @@
         first_dilation = 1 if dilation in (1, 2) else 2
         bkwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            use_se=use_se, **kwargs)
+            use_se=use_se, use_eca=use_eca, **kwargs)
         layers = [block(
             self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
         self.inplanes = planes * block.expansion
@@ -944,3 +961,20 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
+
+@register_model
+def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs an ECA-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than in the 'T' variant.
+    This model replaces the SE module with the ECA module.
+    """
+    default_cfg = default_cfgs['ecaresnext26tn_32x4d']
+    model = ResNet(
+        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
+        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
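With the `@register_model` entry in place, the new architecture should be constructible through the usual timm factory. A quick smoke test of this branch (a sketch, not part of the patch; `pretrained=False` because no weights have been released yet, as the empty `url` in `default_cfgs` indicates):

```python
import torch
import timm

# build the new ECA variant from the model registry
model = timm.create_model('ecaresnext26tn_32x4d', pretrained=False)
model.eval()

# forward a dummy ImageNet-sized batch through the network
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))

print(logits.shape)  # expected: torch.Size([1, 1000])
```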