def padding1d_circular(input, pad):
    r"""Circularly pad dim 2 of a tensor (manual stand-in for 1d circular padding).

    Elements that fall off one end of the sequence are taken from the other
    end, so the padded sequence reads as if the original were tiled
    periodically. Pad amounts larger than the sequence length simply wrap
    around more than once.

    Example::

        input:  torch.tensor([[[0., 1., 2.],
                               [3., 4., 5.]]])
        pad:    (1, 2)
        output: tensor([[[2., 0., 1., 2., 0., 1.],
                         [5., 3., 4., 5., 3., 4.]]])

    Args:
        input: tensor with at least 3 dims; padding is applied along dim 2.
        pad: sequence whose last two entries are the ``(left, right)`` pad
            amounts; intended for non-negative integers.

    Returns:
        A new tensor whose dim 2 has size ``input.shape[2] + left + right``.

    Raises:
        ValueError: if dim 2 is empty and a non-zero pad is requested
            (there is nothing to wrap around).

    from: https://github.com/pytorch/pytorch/issues/24504
    """
    left, right = pad[-2], pad[-1]
    length = input.shape[2]

    if length == 0:
        if left or right:
            raise ValueError("cannot circularly pad an empty dimension")
        return input

    # Source position of every output element. Tensor `%` follows the sign of
    # the divisor, so the negative positions on the left map to the tail of
    # the sequence, and positions >= length map back to its head. This avoids
    # the `-0` slicing special case of a slice-and-concatenate approach and
    # stays correct when a pad amount exceeds `length`.
    wrapped = torch.arange(-left, length + right, device=input.device) % length
    return input.index_select(2, wrapped)
@@ -97,20 +114,20 @@ class CircularEfficientChannelAttn(nn.Module): if channels is not None: t = int(abs(math.log(channels, 2) + beta) / gamma) kernel_size = max(t if t % 2 else t + 1, 3) - self.padding = (kernel_size - 1) // 2 - self.avg_pool = nn.AdaptiveAvgPool2d(1) - #pytorch circular padding mode is buggy as of pytorch 1.4 - #see https://github.com/pytorch/pytorch/pull/17240 + # pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually + # see https://github.com/pytorch/pytorch/pull/17240 + # https://github.com/pytorch/pytorch/issues/24504 + p = (kernel_size - 1) // 2 + self.padding = (p, p) - #implement manual circular padding + self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False) self.gate_fn = gate_fn def forward(self, x): - y = self.avg_pool(x) # Feature descriptor on the global spatial information - # Manually implement circular padding, F.pad does not seemed to be bugged - y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular') + y = self.avg_pool(x) + y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding) # manual circular padding y = self.conv(y) y = y.view(x.shape[0], -1, 1, 1) y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y) diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index c4d3586f..fb8c5cc4 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -413,7 +413,7 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), act_layer=act_layer, - attn_layer='eca', + attn_layer='ceca', attn_kwargs=dict(gate_fn=hard_sigmoid), **kwargs, )