Add CecaBam

CecaBam applies circular ECA for channel-wise attention, followed by a BAM-style spatial gate for spatial attention.
pull/87/head
Chris Ha 6 years ago
parent e6a762346a
commit dbdb3ce963

@ -32,6 +32,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import math
import torch
from torch import nn
import torch.nn.functional as F
@ -119,3 +120,33 @@ class CecaModule(nn.Module):
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
return x * y.expand_as(x)
class ChannelPool(nn.Module):
    """Squeeze the channel dimension into a 2-channel descriptor.

    Stacks the per-position channel max and channel mean, producing the
    (B, 2, H, W) input expected by the spatial attention convolution.
    """

    def forward(self, x):
        max_map = torch.max(x, dim=1, keepdim=True)[0]
        mean_map = torch.mean(x, dim=1, keepdim=True)
        return torch.cat((max_map, mean_map), dim=1)
class SpatialGate(nn.Module):
    """BAM-style spatial attention gate.

    Compresses channels to a (max, mean) pair, convolves it down to a
    single-channel attention map, and rescales the input by the sigmoid
    of that map (broadcast across channels).
    """

    def __init__(self):
        super(SpatialGate, self).__init__()
        # 7x7 receptive field; 'same' padding keeps spatial dims intact
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = nn.Conv2d(
            2, 1, kernel_size, stride=1,
            padding=(kernel_size - 1) // 2, bias=False)

    def forward(self, x):
        attn = self.spatial(self.compress(x))
        # single-channel sigmoid map broadcasts over the channel axis
        return x * torch.sigmoid(attn)
class CecaBam(nn.Module):
    """Circular ECA channel attention, optionally followed by a BAM-style
    spatial gate.

    Args:
        gate_channels: channel count of the feature map being gated.
        no_spatial: when True, apply only the channel attention branch.
    """

    def __init__(self, gate_channels, no_spatial=False):
        super(CecaBam, self).__init__()
        self.CecaModule = CecaModule(gate_channels)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        y = self.CecaModule(x)
        if self.no_spatial:
            return y
        return self.SpatialGate(y)

Loading…
Cancel
Save