You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/se.py

41 lines
1.5 KiB

from torch import nn as nn
from .create_act import get_act_fn
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel attention block.

    The input is average-pooled to a 1x1 spatial map (one value per channel).
    That map goes through a 1x1-conv bottleneck (``fc1`` -> ``act`` -> ``fc2``),
    and the gated result rescales the input channel-wise.

    Args:
        channels: number of input channels; the block returns the same count.
        reduction: divisor applied to ``channels`` to size the bottleneck when
            ``reduction_channels`` is not given.
        act_layer: activation module class used inside the bottleneck. It is
            instantiated with ``inplace=True``, so it must accept that keyword.
        min_channels: lower bound on the derived bottleneck width.
        reduction_channels: explicit bottleneck width. When falsy (``None`` or
            ``0``), the width is ``max(channels // reduction, min_channels)``.
        gate_fn: name of the gating function, resolved through ``get_act_fn``.
    """

    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
                 gate_fn='sigmoid'):
        super().__init__()

        hidden = reduction_channels
        if not hidden:
            # No explicit width given: derive it from the reduction ratio,
            # but never go below min_channels.
            hidden = max(channels // reduction, min_channels)

        # Construction order is kept (pool, fc1, act, fc2). Conv init draws from
        # the RNG, and registration order fixes the state_dict key order.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels=channels, out_channels=hidden,
                             kernel_size=1, padding=0, bias=True)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(in_channels=hidden, out_channels=channels,
                             kernel_size=1, padding=0, bias=True)
        self.gate_fn = get_act_fn(gate_fn)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned SE gate."""
        # Squeeze: collapse spatial dims to 1x1.
        squeezed = self.avg_pool(x)
        # Excite: bottleneck built from 1x1 convs around the activation.
        excitation = self.fc2(self.act(self.fc1(squeezed)))
        return x * self.gate_fn(excitation)
class EffectiveSEModule(nn.Module):
    """Effective Squeeze-Excitation (eSE).

    From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667

    This block has no bottleneck. A single 1x1 conv maps the average-pooled
    descriptor from ``channel`` to ``channel`` values. These are gated and used
    to rescale the input channel-wise.

    Args:
        channel: number of input channels; the block returns the same count.
        gate_fn: name of the gating function, resolved through ``get_act_fn``.
            ``forward`` calls it with ``inplace=True``, so the resolved
            callable must accept that keyword.
    """

    def __init__(self, channel, gate_fn='hard_sigmoid'):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_channels=channel, out_channels=channel,
                            kernel_size=1, padding=0)
        self.gate_fn = get_act_fn(gate_fn)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the eSE gate."""
        descriptor = self.fc(self.avg_pool(x))
        # The in-place gate acts on the freshly computed conv output,
        # not on the caller's tensor ``x``.
        gate = self.gate_fn(descriptor, inplace=True)
        return x * gate