From dbdb3ce963b9f06cb5c0c51456fd9ceb8e963ef7 Mon Sep 17 00:00:00 2001
From: Chris Ha <15088501+VRandme@users.noreply.github.com>
Date: Sun, 9 Feb 2020 23:09:13 +0900
Subject: [PATCH] Add CecaBam

CecaBam applies circular ECA for channel-wise attention, followed by a
BAM-style spatial gate for spatial attention.
---
 timm/models/EcaModule.py | 34 ++++++++++++++++++++++++++++++++++
 1 file changed, 34 insertions(+)

diff --git a/timm/models/EcaModule.py b/timm/models/EcaModule.py
index b91b5801..09ef0038 100644
--- a/timm/models/EcaModule.py
+++ b/timm/models/EcaModule.py
@@ -32,6 +32,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 '''
 import math
+import torch
 from torch import nn
 import torch.nn.functional as F
 
@@ -119,3 +120,36 @@ class CecaModule(nn.Module):
         y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
 
         return x * y.expand_as(x)
+
+class ChannelPool(nn.Module):
+    """Stack max-pool and mean-pool over the channel dim into a 2-channel map."""
+    def forward(self, x):
+        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
+
+class SpatialGate(nn.Module):
+    """BAM-style spatial attention: 7x7 conv over pooled channel statistics."""
+    def __init__(self):
+        super(SpatialGate, self).__init__()
+        kernel_size = 7
+        self.compress = ChannelPool()
+        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, bias=False)
+    def forward(self, x):
+        x_compress = self.compress(x)
+        x_out = self.spatial(x_compress)
+        scale = torch.sigmoid(x_out)  # (N, 1, H, W) map, broadcast over channels
+        return x * scale
+
+
+class CecaBam(nn.Module):
+    """Circular ECA channel attention followed by an optional spatial gate."""
+    def __init__(self, gate_channels, no_spatial=False):
+        super(CecaBam, self).__init__()
+        self.CecaModule = CecaModule(gate_channels)
+        self.no_spatial = no_spatial
+        if not no_spatial:
+            self.SpatialGate = SpatialGate()
+    def forward(self, x):
+        x_out = self.CecaModule(x)
+        if not self.no_spatial:
+            x_out = self.SpatialGate(x_out)
+        return x_out
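
A minimal usage sketch for the added CecaBam block (not part of the patch itself;
it assumes the patched timm/models/EcaModule.py is importable, and the tensor
shape below is an arbitrary example):

    # Smoke test for CecaBam: channel attention, then optional spatial attention.
    import torch
    from timm.models.EcaModule import CecaBam

    block = CecaBam(gate_channels=64)   # circular ECA + BAM-style spatial gate
    x = torch.randn(2, 64, 56, 56)      # (batch, channels, height, width)
    y = block(x)                        # attention only rescales features,
    assert y.shape == x.shape           # so the output shape is unchanged

    # Channel-only variant: skip the spatial gate entirely.
    block_no_spatial = CecaBam(gate_channels=64, no_spatial=True)
    assert block_no_spatial(x).shape == x.shape

Because both gates multiply the input by a sigmoid map, CecaBam can be dropped
into a residual block wherever a plain ECA module would go, with the spatial
gate adding only a single 7x7, 2-channel convolution on top.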