AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to avoid NaN with AMP + channels_last layout. See https://github.com/pytorch/pytorch/issues/43992
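For reference, a minimal sketch (not part of the commit; shapes are illustrative) of why the swap is safe: global average pooling to 1x1 via nn.AdaptiveAvgPool2d(1) and a spatial mean over dims (2, 3) produce the same values, so replacing the pooling module with x.mean((2, 3), keepdim=True) changes nothing numerically while avoiding the adaptive-pool path that can emit NaNs under AMP + channels_last per the linked issue.

import torch
from torch import nn

# Illustrative NCHW input in channels_last memory format.
x = torch.randn(2, 64, 7, 7).to(memory_format=torch.channels_last)

pooled = nn.AdaptiveAvgPool2d(1)(x)                # (2, 64, 1, 1)
meaned = x.mean((2, 3), keepdim=True)              # (2, 64, 1, 1), same values
print(torch.allclose(pooled, meaned, atol=1e-6))   # True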

pull/233/head
Ross Wightman 4 years ago
parent c2cd1a332e
commit 110a7c4982

@@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                  act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
         super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+        self.gate_fn = gate_fn

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
+        return x * self.gate_fn(x_se)


 class ConvBnAct(nn.Module):
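For context, a self-contained sketch of the gating pattern the rewritten forward uses (MiniSE and its names are illustrative, not timm API): pool with a plain mean, run the 1x1 reduce/expand bottleneck, and scale the input by the gated result.

import torch
from torch import nn

class MiniSE(nn.Module):
    # Minimal stand-in mirroring the post-change SqueezeExcite forward path.
    def __init__(self, chs, se_ratio=0.25):
        super().__init__()
        rd_chs = max(1, int(chs * se_ratio))
        self.conv_reduce = nn.Conv2d(chs, rd_chs, 1, bias=True)
        self.act1 = nn.ReLU(inplace=True)
        self.conv_expand = nn.Conv2d(rd_chs, chs, 1, bias=True)

    def forward(self, x):
        x_se = x.mean((2, 3), keepdim=True)  # global pool without AdaptiveAvgPool2d
        x_se = self.conv_expand(self.act1(self.conv_reduce(x_se)))
        return x * x_se.sigmoid()

print(MiniSE(32)(torch.randn(1, 32, 14, 14)).shape)  # torch.Size([1, 32, 14, 14])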

@@ -10,6 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .conv_bn_act import ConvBnAct

@@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
     """
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
         super(ChannelAttn, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)

     def forward(self, x):
-        x_avg = self.avg_pool(x)
-        x_max = self.max_pool(x)
+        x_avg = x.mean((2, 3), keepdim=True)
+        x_max = F.adaptive_max_pool2d(x, 1)
         x_avg = self.fc2(self.act(self.fc1(x_avg)))
         x_max = self.fc2(self.act(self.fc1(x_max)))
         x_attn = x_avg + x_max

@@ -40,7 +39,7 @@ class LightChannelAttn(ChannelAttn):
         super(LightChannelAttn, self).__init__(channels, reduction)

     def forward(self, x):
-        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
+        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
         x_attn = self.fc2(self.act(self.fc1(x_pool)))
         return x * x_attn.sigmoid()

@@ -52,22 +52,15 @@
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert kernel_size % 2 == 1
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-        # Reshape for convolution
-        y = y.view(x.shape[0], 1, -1)
-        # Two different branches of ECA module
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)
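As a worked example of the adaptive kernel-size rule this hunk keeps unchanged (channel counts below are assumed, not from the commit): with gamma=2, beta=1 the rule is t = int(abs(log2(C) + 1) / 2), bumped to the next odd value if even and clamped to at least 3. The rewritten forward then flattens the (N, C) pooled descriptor to (N, 1, C) so the Conv1d slides across channels.

import math

def eca_kernel_size(channels, gamma=2, beta=1):
    # Same rule as EcaModule.__init__ above.
    t = int(abs(math.log(channels, 2) + beta) / gamma)
    return max(t if t % 2 else t + 1, 3)

print(eca_kernel_size(64))   # log2(64)=6  -> t=3 (odd)  -> 3
print(eca_kernel_size(256))  # log2(256)=8 -> t=4 (even) -> 5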
@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert kernel_size % 2 == 1
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-        #implement manual circular padding
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-        # Two different branches of ECA module
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)
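A small sketch (toy values, not from the commit) of what the manual circular padding does to the (N, 1, C) channel descriptor before the padding-free Conv1d: the last self.padding channels wrap around to the front and the first ones wrap to the end, so the 1d conv treats the channel dimension as circular.

import torch
import torch.nn.functional as F

y = torch.arange(6.).view(1, 1, 6)          # stand-in descriptor: N=1, 1, C=6
padded = F.pad(y, (1, 1), mode='circular')  # wrap one channel on each side
print(padded)  # tensor([[[5., 0., 1., 2., 3., 4., 5., 0.]]])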

@@ -1,40 +1,36 @@
 from torch import nn as nn

-from .create_act import get_act_fn
+from .create_act import create_act_layer


 class SEModule(nn.Module):

     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_fn='sigmoid'):
+                 gate_layer='sigmoid'):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or max(channels // reduction, min_channels)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.act(x_se)
         x_se = self.fc2(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)


 class EffectiveSEModule(nn.Module):
     """ 'Effective Squeeze-Excitation
     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
     """
-    def __init__(self, channels, gate_fn='hard_sigmoid'):
+    def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.gate = create_act_layer(gate_layer, inplace=True)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc(x_se)
-        return x * self.gate_fn(x_se, inplace=True)
+        return x * self.gate(x_se)
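Besides dropping the pooling module, this hunk also switches the gate from a plain function to a constructed module. A rough sketch of the difference, assuming the helpers are importable from timm.models.layers.create_act (the import path is not shown in this diff): get_act_fn returns a callable that takes inplace at call time, while create_act_layer builds an activation nn.Module with inplace fixed at construction, which is why EffectiveSEModule now passes inplace=True to create_act_layer instead of to the forward call.

# Sketch only; the exact import location is an assumption.
from timm.models.layers.create_act import get_act_fn, create_act_layer

gate_fn = get_act_fn('hard_sigmoid')                    # function: gate_fn(x, inplace=True)
gate = create_act_layer('hard_sigmoid', inplace=True)   # nn.Module: gate(x)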

@@ -27,7 +27,6 @@ class SelectiveKernelAttn(nn.Module):
         """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
         self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
         self.bn = norm_layer(attn_channels)
         self.act = act_layer(inplace=True)

@@ -35,8 +34,7 @@
     def forward(self, x):
         assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        x = self.pool(x)
+        x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
         x = self.act(x)
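A shape sketch (dimensions are illustrative) of the fused line above: the attention input stacks the selective-kernel paths on dim 1, so summing over dim 1 and then taking the spatial mean reproduces what torch.sum followed by AdaptiveAvgPool2d did.

import torch

x = torch.randn(2, 2, 32, 7, 7)           # (N, num_paths, C, H, W)
x = x.sum(1).mean((2, 3), keepdim=True)   # sum paths, then global average pool
print(x.shape)                            # torch.Size([2, 32, 1, 1])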

@@ -59,18 +59,15 @@ class SEWithNorm(nn.Module):
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.bn(x_se)
         x_se = self.act(x_se)

@@ -71,17 +71,14 @@ class SEModule(nn.Module):
     def __init__(self, channels, reduction):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc1 = nn.Conv2d(
-            channels, channels // reduction, kernel_size=1, padding=0)
+        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
         self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(
-            channels // reduction, channels, kernel_size=1, padding=0)
+        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         module_input = x
-        x = self.avg_pool(x)
+        x = x.mean((2, 3), keepdim=True)
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)

@@ -56,10 +56,9 @@ class FastGlobalAvgPool2d(nn.Module):
     def forward(self, x):
         if self.flatten:
-            in_size = x.size()
-            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+            return x.mean((2, 3))
         else:
-            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+            return x.mean((2, 3), keepdim=True)

     def feat_mult(self):
         return 1
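A quick check (illustrative shapes, not from the commit) that the two simplified branches keep the old output shapes: the flatten path returns (N, C) and the default path returns (N, C, 1, 1), just as the previous view/mean reshuffling did.

import torch

x = torch.randn(2, 64, 7, 7)
print(x.mean((2, 3)).shape)                # torch.Size([2, 64])        flatten=True
print(x.mean((2, 3), keepdim=True).shape)  # torch.Size([2, 64, 1, 1])  flatten=False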
