Bring in code that should fix 1d circular padding properly, tweaks to ECA impl, using CECA in MobileNetV3 experiment

efficientnet_attn
Ross Wightman 5 years ago
parent 67e759f710
commit ade1ba5fe3

@@ -34,6 +34,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
 import math
+import torch
 from torch import nn
 import torch.nn.functional as F
@@ -62,14 +63,30 @@ class EfficientChannelAttn(nn.Module):
         self.gate_fn = gate_fn

     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        y = y.view(x.shape[0], 1, -1)  # Reshape for convolution
+        y = self.avg_pool(x)
+        y = y.view(x.shape[0], 1, -1)  # Reshape 4d -> 3d for 1d convolution
         y = self.conv(y)
-        y = y.view(x.shape[0], -1, 1, 1)
+        y = y.view(x.shape[0], -1, 1, 1)  # Back to 4d
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
         return x * y.expand_as(x)

+def padding1d_circular(input, pad):
+    r"""input: torch.tensor([[[0., 1., 2.],
+                              [3., 4., 5.]]])
+    pad: (1, 2)
+    output: tensor([[[2., 0., 1., 2., 0., 1.],
+                     [5., 3., 4., 5., 3., 4.]]])
+    from: https://github.com/pytorch/pytorch/issues/24504
+    """
+    input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
+    if pad[-1] == 0 and pad[-2] != 0:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
+    else:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)
+
 class CircularEfficientChannelAttn(nn.Module):
     """Constructs a circular ECA module.
@@ -97,20 +114,20 @@ class CircularEfficientChannelAttn(nn.Module):
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
-        self.padding = (kernel_size - 1) // 2
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-        #implement manual circular padding
+        # pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # https://github.com/pytorch/pytorch/issues/24504
+        p = (kernel_size - 1) // 2
+        self.padding = (p, p)
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.gate_fn = gate_fn

     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
+        y = self.avg_pool(x)
+        y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding)  # manual circular padding
         y = self.conv(y)
         y = y.view(x.shape[0], -1, 1, 1)
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
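
As a quick sanity check of the new helper (not part of the commit), the snippet below runs padding1d_circular on the example from its docstring. On PyTorch builds where the circular padding bug tracked in pytorch#24504 has been fixed, F.pad with mode='circular' should produce the same tensor; the assertion is a hedged expectation under that assumption.

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[0., 1., 2.],
                       [3., 4., 5.]]])
    pad = (1, 2)  # (left, right) padding of the last dim

    y_manual = padding1d_circular(x, pad)  # helper added in this commit
    # tensor([[[2., 0., 1., 2., 0., 1.],
    #          [5., 3., 4., 5., 3., 4.]]])

    y_builtin = F.pad(x, pad, mode='circular')
    assert torch.equal(y_manual, y_builtin)  # holds once the upstream padding bug is fixed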

@@ -413,7 +413,7 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
         act_layer=act_layer,
-        attn_layer='eca',
+        attn_layer='ceca',
         attn_kwargs=dict(gate_fn=hard_sigmoid),
         **kwargs,
     )
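
The 'ceca' layer selected above uses the same adaptive kernel size rule shown in the earlier hunk. As a rough sketch of how that rule behaves across channel counts: beta=1 and gamma=2 are the commonly cited ECA defaults and are an assumption here, since their values do not appear in this diff.

    import math

    def eca_kernel_size(channels, beta=1, gamma=2):
        # beta/gamma defaults are assumed typical ECA values, not taken from this commit
        t = int(abs(math.log(channels, 2) + beta) / gamma)
        return max(t if t % 2 else t + 1, 3)  # force an odd kernel, minimum 3

    for c in (16, 64, 256, 1024):
        print(c, eca_kernel_size(c))  # 16 -> 3, 64 -> 3, 256 -> 5, 1024 -> 5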
