From 5d2df0746e17e6ae9211722ce5a5dd05113bc8c3 Mon Sep 17 00:00:00 2001
From: Chris Ha <15088501+VRandme@users.noreply.github.com>
Date: Tue, 11 Feb 2020 22:55:07 +0900
Subject: [PATCH] refactor cecabam with cbam

---
 timm/models/layers/eca.py | 31 +++++++------------------------
 1 file changed, 7 insertions(+), 24 deletions(-)

diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py
index f4072aeb..e763db35 100644
--- a/timm/models/layers/eca.py
+++ b/timm/models/layers/eca.py
@@ -36,6 +36,7 @@ SOFTWARE.
 import math
 import torch
 from torch import nn
+from .cbam import SpatialAttn
 import torch.nn.functional as F
 
 
@@ -124,32 +125,14 @@ class CecaModule(nn.Module):
         return x * y.expand_as(x)
 
 
-class ChannelPool(nn.Module):
-    def forward(self, x):
-        return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
-
-class SpatialGate(nn.Module):
-    def __init__(self):
-        super(SpatialGate, self).__init__()
-        kernel_size = 7
-        self.compress = ChannelPool()
-        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
-    def forward(self, x):
-        x_compress = self.compress(x)
-        x_out = self.spatial(x_compress)
-        scale = torch.sigmoid(x_out) # broadcasting
-        return x * scale
-
 class CecaBam(nn.Module):
-    def __init__(self, gate_channels, no_spatial=False):
+    def __init__(self, gate_channels):
         super(CecaBam, self).__init__()
         self.CecaModule = CecaModule(gate_channels)
-        self.no_spatial=no_spatial
-        if not no_spatial:
-            self.SpatialGate = SpatialGate()
+        self.SpatialAttention = SpatialAttn(kernel_size=7)
+
     def forward(self, x):
-        x_out = self.CecaModule(x)
-        if not self.no_spatial:
-            x_out = self.SpatialGate(x_out)
-        return x_out
+        y = self.CecaModule(x)
+        y = self.SpatialAttention(y)
+        return y
 
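
For context, a minimal usage sketch of the refactored block (assumptions: CecaBam is importable as timm.models.layers.eca.CecaBam, and SpatialAttn from cbam.py rescales the input with a 7x7 spatial attention map without changing its shape; the tensor sizes below are illustrative only, not part of the patch):

    import torch
    from timm.models.layers.eca import CecaBam

    # Hypothetical smoke test; shapes are assumptions, not part of the patch.
    x = torch.randn(2, 64, 32, 32)     # (batch, channels, height, width)
    block = CecaBam(gate_channels=64)  # circular ECA channel attention + 7x7 spatial attention
    y = block(x)                       # channel gating first, then spatial gating
    assert y.shape == x.shape          # both attention stages only rescale the features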