from torch import nn as nn

from .create_act import get_act_fn


class SEModule(nn.Module):
    """Squeeze-and-Excitation channel attention block.

    Squeezes the input to one value per channel with global average pooling,
    passes that descriptor through a 1x1-conv bottleneck
    (reduce -> activation -> expand), applies a gating function, and rescales
    the input channel-wise by the result.

    Args:
        channels: Number of input (and output) channels.
        reduction: Divisor used to derive the bottleneck width when
            ``reduction_channels`` is not given.
        act_layer: Activation module class used inside the bottleneck; it is
            constructed as ``act_layer(inplace=True)``.
        min_channels: Lower bound on the derived bottleneck width.
        reduction_channels: Explicit bottleneck width. Any falsy value
            (``None`` or ``0``) falls back to
            ``max(channels // reduction, min_channels)``.
        gate_fn: Name (or spec) of the gate activation, resolved through
            ``get_act_fn``.
    """

    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8,
                 reduction_channels=None, gate_fn='sigmoid'):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # `or` rather than `is None`: an explicit 0 also means "derive the width".
        reduction_channels = reduction_channels or max(channels // reduction, min_channels)
        self.fc1 = nn.Conv2d(
            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(
            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
        self.gate_fn = get_act_fn(gate_fn)

    def forward(self, x):
        """Return ``x`` rescaled per channel by the learned gate.

        ``x`` must have ``channels`` channels; its spatial size is arbitrary
        because pooling is adaptive (presumably an (N, C, H, W) layout).
        """
        x_se = self.avg_pool(x)  # spatial dims collapsed to 1x1
        x_se = self.fc1(x_se)
        x_se = self.act(x_se)
        x_se = self.fc2(x_se)
        # The 1x1 gate broadcasts over the spatial dims of x.
        return x * self.gate_fn(x_se)


class EffectiveSEModule(nn.Module):
    """'Effective Squeeze-Excitation' (eSE) block.

    From `CenterMask : Real-Time Anchor-Free Instance Segmentation`
    - https://arxiv.org/abs/1911.06667

    Unlike ``SEModule`` there is no bottleneck: a single 1x1 conv maps the
    pooled descriptor from ``channels`` to ``channels`` before gating.

    Args:
        channels: Number of input (and output) channels.
        gate_fn: Name (or spec) of the gate activation, resolved through
            ``get_act_fn``. The resolved callable is invoked with
            ``inplace=True``, so it must accept that keyword.
    """

    def __init__(self, channels, gate_fn='hard_sigmoid'):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
        self.gate_fn = get_act_fn(gate_fn)

    def forward(self, x):
        """Return ``x`` rescaled per channel by the eSE gate."""
        x_se = self.avg_pool(x)
        x_se = self.fc(x_se)
        # x_se is a fresh intermediate owned by this call, so gating it in
        # place avoids one extra allocation without touching the input x.
        return x * self.gate_fn(x_se, inplace=True)