Replace nn.AdaptiveAvgPool2d(1) with x.mean((2, 3), keepdim=True) in all SE/attn layers to avoid NaNs when using AMP with the channels_last memory format. See https://github.com/pytorch/pytorch/issues/43992
parent
c2cd1a332e
commit
110a7c4982
@ -1,40 +1,36 @@
|
|||||||
from torch import nn as nn
|
from torch import nn as nn
|
||||||
from .create_act import get_act_fn
|
from .create_act import create_act_layer
|
||||||
|
|
||||||
|
|
||||||
class SEModule(nn.Module):
|
class SEModule(nn.Module):
|
||||||
|
|
||||||
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
|
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
|
||||||
gate_fn='sigmoid'):
|
gate_layer='sigmoid'):
|
||||||
super(SEModule, self).__init__()
|
super(SEModule, self).__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
|
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
|
||||||
self.fc1 = nn.Conv2d(
|
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
|
||||||
channels, reduction_channels, kernel_size=1, padding=0, bias=True)
|
|
||||||
self.act = act_layer(inplace=True)
|
self.act = act_layer(inplace=True)
|
||||||
self.fc2 = nn.Conv2d(
|
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
|
||||||
reduction_channels, channels, kernel_size=1, padding=0, bias=True)
|
self.gate = create_act_layer(gate_layer)
|
||||||
self.gate_fn = get_act_fn(gate_fn)
|
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_se = self.avg_pool(x)
|
x_se = x.mean((2, 3), keepdim=True)
|
||||||
x_se = self.fc1(x_se)
|
x_se = self.fc1(x_se)
|
||||||
x_se = self.act(x_se)
|
x_se = self.act(x_se)
|
||||||
x_se = self.fc2(x_se)
|
x_se = self.fc2(x_se)
|
||||||
return x * self.gate_fn(x_se)
|
return x * self.gate(x_se)
|
||||||
|
|
||||||
|
|
||||||
class EffectiveSEModule(nn.Module):
|
class EffectiveSEModule(nn.Module):
|
||||||
""" 'Effective Squeeze-Excitation
|
""" 'Effective Squeeze-Excitation
|
||||||
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
|
From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
|
||||||
"""
|
"""
|
||||||
def __init__(self, channels, gate_fn='hard_sigmoid'):
|
def __init__(self, channels, gate_layer='hard_sigmoid'):
|
||||||
super(EffectiveSEModule, self).__init__()
|
super(EffectiveSEModule, self).__init__()
|
||||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
||||||
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
|
||||||
self.gate_fn = get_act_fn(gate_fn)
|
self.gate = create_act_layer(gate_layer, inplace=True)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x_se = self.avg_pool(x)
|
x_se = x.mean((2, 3), keepdim=True)
|
||||||
x_se = self.fc(x_se)
|
x_se = self.fc(x_se)
|
||||||
return x * self.gate_fn(x_se, inplace=True)
|
return x * self.gate(x_se)
|
||||||
|
Loading…
Reference in new issue