diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py
index f4072aeb..e763db35 100644
--- a/timm/models/layers/eca.py
+++ b/timm/models/layers/eca.py
@@ -36,6 +36,7 @@ SOFTWARE.
 import math
 import torch
 from torch import nn
+from .cbam import SpatialAttn
 import torch.nn.functional as F
 
 
@@ -124,32 +125,14 @@ class CecaModule(nn.Module):
         return x * y.expand_as(x)
 
 
-class ChannelPool(nn.Module):
-    def forward(self, x):
-        return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
-
-class SpatialGate(nn.Module):
-    def __init__(self):
-        super(SpatialGate, self).__init__()
-        kernel_size = 7
-        self.compress = ChannelPool()
-        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2)
-    def forward(self, x):
-        x_compress = self.compress(x)
-        x_out = self.spatial(x_compress)
-        scale = torch.sigmoid(x_out)  # broadcasting
-        return x * scale
-
 class CecaBam(nn.Module):
-    def __init__(self, gate_channels, no_spatial=False):
+    def __init__(self, gate_channels):
         super(CecaBam, self).__init__()
         self.CecaModule = CecaModule(gate_channels)
-        self.no_spatial=no_spatial
-        if not no_spatial:
-            self.SpatialGate = SpatialGate()
+        self.SpatialAttention = SpatialAttn(kernel_size=7)
+
     def forward(self, x):
-        x_out = self.CecaModule(x)
-        if not self.no_spatial:
-            x_out = self.SpatialGate(x_out)
-        return x_out
+        y = self.CecaModule(x)
+        y = self.SpatialAttention(y)
+        return y
 
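
The patch drops the local ChannelPool/SpatialGate copies and instead composes CecaModule (circular ECA channel attention) with the SpatialAttn block already defined in timm/models/layers/cbam.py. A minimal sketch of exercising the revised block is shown below; it assumes the patch has been applied so CecaBam is importable from timm.models.layers.eca, and that cbam.SpatialAttn accepts a kernel_size argument as used in the diff.

    import torch
    # assumes this PR is applied, so CecaBam lives in timm/models/layers/eca.py
    from timm.models.layers.eca import CecaBam

    # dummy feature map: (batch, channels, height, width)
    x = torch.randn(2, 64, 56, 56)

    # channel attention (circular ECA) followed by CBAM-style spatial attention
    block = CecaBam(gate_channels=64)
    y = block(x)

    # both attention stages only rescale features, so the shape is unchanged
    assert y.shape == x.shape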