@@ -36,6 +36,7 @@ SOFTWARE.
 import math
 import torch
 from torch import nn
+from .cbam import SpatialAttn
 import torch.nn.functional as F
 
 
@@ -124,32 +125,14 @@ class CecaModule(nn.Module):
         return x * y.expand_as(x)
 
 
-class ChannelPool(nn.Module):
-    def forward(self, x):
-        return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
-
-
-class SpatialGate(nn.Module):
-    def __init__(self):
-        super(SpatialGate, self).__init__()
-        kernel_size = 7
-        self.compress = ChannelPool()
-        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
-
-    def forward(self, x):
-        x_compress = self.compress(x)
-        x_out = self.spatial(x_compress)
-        scale = torch.sigmoid(x_out)  # broadcasting
-        return x * scale
-
-
 class CecaBam(nn.Module):
-    def __init__(self, gate_channels, no_spatial=False):
+    def __init__(self, gate_channels):
         super(CecaBam, self).__init__()
         self.CecaModule = CecaModule(gate_channels)
-        self.no_spatial=no_spatial
-        if not no_spatial:
-            self.SpatialGate = SpatialGate()
+        self.SpatialAttention = SpatialAttn(kernel_size=7)
 
     def forward(self, x):
-        x_out = self.CecaModule(x)
-        if not self.no_spatial:
-            x_out = self.SpatialGate(x_out)
-        return x_out
+        y = self.CecaModule(x)
+        y = self.SpatialAttention(y)
+        return y
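
After this change CecaBam is unconditionally two-stage: CecaModule re-weights channels, then cbam's SpatialAttn (presumably equivalent to the deleted SpatialGate: max+mean channel pooling followed by a 7x7 conv and a sigmoid gate) re-weights spatial positions; the no_spatial switch is gone, so the spatial stage can no longer be disabled. Both stages only rescale activations, so input and output shapes match. A minimal usage sketch, assuming the patched module is importable as below (the import path is illustrative, not part of this diff):

    import torch

    from eca import CecaBam  # hypothetical path; point this at wherever the patched file lives

    attn = CecaBam(gate_channels=64)  # channel count of the incoming feature map
    x = torch.randn(2, 64, 32, 32)    # NCHW feature map
    y = attn(x)

    # Both attention stages re-weight activations but never reshape them.
    assert y.shape == x.shape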