Add CecaBam

CecaBam is circular ECA followed by BAM for both channelwise and spatial attention
pull/87/head
Chris Ha 6 years ago
parent e6a762346a
commit dbdb3ce963

@ -32,6 +32,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE. SOFTWARE.
''' '''
import math import math
import torch
from torch import nn from torch import nn
import torch.nn.functional as F import torch.nn.functional as F
@ -119,3 +120,33 @@ class CecaModule(nn.Module):
y = self.sigmoid(y.view(x.shape[0], -1, 1, 1)) y = self.sigmoid(y.view(x.shape[0], -1, 1, 1))
return x * y.expand_as(x) return x * y.expand_as(x)
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, bias=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = torch.sigmoid(x_out) # broadcasting
return x * scale
class CecaBam(nn.Module):
def __init__(self, gate_channels, no_spatial=False):
super(CecaBam, self).__init__()
self.CecaModule = CecaModule(gate_channels)
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_out = self.CecaModule(x)
if not self.no_spatial:
x_out = self.SpatialGate(x_out)
return x_out

Loading…
Cancel
Save