|
|
|
@ -34,6 +34,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
|
|
|
SOFTWARE.
|
|
|
|
|
"""
|
|
|
|
|
import math
|
|
|
|
|
import torch
|
|
|
|
|
from torch import nn
|
|
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
@ -62,14 +63,30 @@ class EfficientChannelAttn(nn.Module):
|
|
|
|
|
self.gate_fn = gate_fn
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
y = self.avg_pool(x) # Feature descriptor on the global spatial information
|
|
|
|
|
y = y.view(x.shape[0], 1, -1) # Reshape for convolution
|
|
|
|
|
y = self.avg_pool(x)
|
|
|
|
|
y = y.view(x.shape[0], 1, -1) # Reshape 4d -> 3d for 1d convolution
|
|
|
|
|
y = self.conv(y)
|
|
|
|
|
y = y.view(x.shape[0], -1, 1, 1)
|
|
|
|
|
y = y.view(x.shape[0], -1, 1, 1) # Back to 4d
|
|
|
|
|
y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
|
|
|
|
|
return x * y.expand_as(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def padding1d_circular(input, pad):
|
|
|
|
|
r"""input: torch.tensor([[[0., 1., 2.],
|
|
|
|
|
[3., 4., 5.]]])
|
|
|
|
|
pad: (1, 2)
|
|
|
|
|
output: tensor([[[2., 0., 1., 2., 0., 1.],
|
|
|
|
|
[5., 3., 4., 5., 3., 4.]]])
|
|
|
|
|
|
|
|
|
|
from: https://github.com/pytorch/pytorch/issues/24504
|
|
|
|
|
"""
|
|
|
|
|
input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
|
|
|
|
|
if pad[-1] == 0 and pad[-2] != 0:
|
|
|
|
|
return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
|
|
|
|
|
else:
|
|
|
|
|
return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CircularEfficientChannelAttn(nn.Module):
|
|
|
|
|
"""Constructs a circular ECA module.
|
|
|
|
|
|
|
|
|
@ -97,20 +114,20 @@ class CircularEfficientChannelAttn(nn.Module):
|
|
|
|
|
if channels is not None:
|
|
|
|
|
t = int(abs(math.log(channels, 2) + beta) / gamma)
|
|
|
|
|
kernel_size = max(t if t % 2 else t + 1, 3)
|
|
|
|
|
self.padding = (kernel_size - 1) // 2
|
|
|
|
|
|
|
|
|
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
#pytorch circular padding mode is buggy as of pytorch 1.4
|
|
|
|
|
#see https://github.com/pytorch/pytorch/pull/17240
|
|
|
|
|
# pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually
|
|
|
|
|
# see https://github.com/pytorch/pytorch/pull/17240
|
|
|
|
|
# https://github.com/pytorch/pytorch/issues/24504
|
|
|
|
|
p = (kernel_size - 1) // 2
|
|
|
|
|
self.padding = (p, p)
|
|
|
|
|
|
|
|
|
|
#implement manual circular padding
|
|
|
|
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
|
|
|
|
|
self.gate_fn = gate_fn
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
y = self.avg_pool(x) # Feature descriptor on the global spatial information
|
|
|
|
|
# Manually implement circular padding, F.pad does not seemed to be bugged
|
|
|
|
|
y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
|
|
|
|
|
y = self.avg_pool(x)
|
|
|
|
|
y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding) # manual circular padding
|
|
|
|
|
y = self.conv(y)
|
|
|
|
|
y = y.view(x.shape[0], -1, 1, 1)
|
|
|
|
|
y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
|
|
|
|
|