From 110a7c4982db758a33aabd4c7bbea09c08202bc7 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 1 Sep 2020 16:01:36 -0700
Subject: [PATCH] AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to
 avoid NaN with AMP + channels_last layout. See
 https://github.com/pytorch/pytorch/issues/43992

---
 timm/models/efficientnet_blocks.py     |  8 +++----
 timm/models/layers/cbam.py             |  9 ++++----
 timm/models/layers/eca.py              | 29 ++++++--------------------
 timm/models/layers/se.py               | 26 ++++++++++-------------
 timm/models/layers/selective_kernel.py |  4 +---
 timm/models/rexnet.py                  |  9 +++-----
 timm/models/senet.py                   |  9 +++-----
 timm/models/tresnet.py                 |  5 ++---
 8 files changed, 33 insertions(+), 66 deletions(-)

diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py
index d5fdce79..d7421ff4 100644
--- a/timm/models/efficientnet_blocks.py
+++ b/timm/models/efficientnet_blocks.py
@@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                  act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
         super(SqueezeExcite, self).__init__()
-        self.gate_fn = gate_fn
         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
+        self.gate_fn = gate_fn

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
+        return x * self.gate_fn(x_se)


 class ConvBnAct(nn.Module):
diff --git a/timm/models/layers/cbam.py b/timm/models/layers/cbam.py
index 600d51fa..44e2fe6d 100644
--- a/timm/models/layers/cbam.py
+++ b/timm/models/layers/cbam.py
@@ -10,6 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman
 import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .conv_bn_act import ConvBnAct
@@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
     """
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
         super(ChannelAttn, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)

     def forward(self, x):
-        x_avg = self.avg_pool(x)
-        x_max = self.max_pool(x)
+        x_avg = x.mean((2, 3), keepdim=True)
+        x_max = F.adaptive_max_pool2d(x, 1)
         x_avg = self.fc2(self.act(self.fc1(x_avg)))
         x_max = self.fc2(self.act(self.fc1(x_max)))
         x_attn = x_avg + x_max
@@ -40,7 +39,7 @@ class LightChannelAttn(ChannelAttn):
         super(LightChannelAttn, self).__init__(channels, reduction)

     def forward(self, x):
-        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
+        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
         x_attn = self.fc2(self.act(self.fc1(x_pool)))
         return x * x_attn.sigmoid()
diff --git a/timm/models/layers/eca.py b/timm/models/layers/eca.py
index 497c5b04..3a7f8b82 100644
--- a/timm/models/layers/eca.py
+++ b/timm/models/layers/eca.py
@@ -52,22 +52,15 @@ class EcaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
-
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-        # Reshape for convolution
-        y = y.view(x.shape[0], 1, -1)
-        # Two different branches of ECA module
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)
@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-
-        #implement manual circular padding
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-
-        # Two different branches of ECA module
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
-
         return x * y.expand_as(x)
diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py
index 578ebf08..a896fb71 100644
--- a/timm/models/layers/se.py
+++ b/timm/models/layers/se.py
@@ -1,40 +1,36 @@
 from torch import nn as nn
-from .create_act import get_act_fn
+from .create_act import create_act_layer


 class SEModule(nn.Module):

     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_fn='sigmoid'):
+                 gate_layer='sigmoid'):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or max(channels // reduction, min_channels)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.act(x_se)
         x_se = self.fc2(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)


 class EffectiveSEModule(nn.Module):
     """ 'Effective Squeeze-Excitation
     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
     """
-    def __init__(self, channels, gate_fn='hard_sigmoid'):
+    def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.gate = create_act_layer(gate_layer, inplace=True)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc(x_se)
-        return x * self.gate_fn(x_se, inplace=True)
+        return x * self.gate(x_se)
diff --git a/timm/models/layers/selective_kernel.py b/timm/models/layers/selective_kernel.py
index 2efaa487..10bfd0e0 100644
--- a/timm/models/layers/selective_kernel.py
+++ b/timm/models/layers/selective_kernel.py
@@ -27,7 +27,6 @@ class SelectiveKernelAttn(nn.Module):
         """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
         self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
         self.bn = norm_layer(attn_channels)
         self.act = act_layer(inplace=True)
@@ -35,8 +34,7 @@ class SelectiveKernelAttn(nn.Module):

     def forward(self, x):
         assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        x = self.pool(x)
+        x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
         x = self.act(x)
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 03e6ee02..b7522a05 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -59,18 +59,15 @@ class SEWithNorm(nn.Module):
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.bn(x_se)
         x_se = self.act(x_se)
diff --git a/timm/models/senet.py b/timm/models/senet.py
index 2155ec81..8073229a 100644
--- a/timm/models/senet.py
+++ b/timm/models/senet.py
@@ -71,17 +71,14 @@ class SEModule(nn.Module):

     def __init__(self, channels, reduction):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc1 = nn.Conv2d(
-            channels, channels // reduction, kernel_size=1, padding=0)
+        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
         self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(
-            channels // reduction, channels, kernel_size=1, padding=0)
+        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         module_input = x
-        x = self.avg_pool(x)
+        x = x.mean((2, 3), keepdim=True)
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
diff --git a/timm/models/tresnet.py b/timm/models/tresnet.py
index e060c20d..75b545e5 100644
--- a/timm/models/tresnet.py
+++ b/timm/models/tresnet.py
@@ -56,10 +56,9 @@ class FastGlobalAvgPool2d(nn.Module):

     def forward(self, x):
         if self.flatten:
-            in_size = x.size()
-            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+            return x.mean((2, 3))
         else:
-            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+            return x.mean((2, 3), keepdim=True)

     def feat_mult(self):
         return 1
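
Note: for a 1x1 output, adaptive average pooling is just a mean over the two
spatial dims, so the swap changes no math; only the kernel that runs changes.
A minimal sketch checking the equivalence (shapes here are arbitrary, not from
the patch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 64, 7, 7).to(memory_format=torch.channels_last)

    # Global average pooling to 1x1 == mean over dims (2, 3).
    pooled = F.adaptive_avg_pool2d(x, 1)
    meaned = x.mean((2, 3), keepdim=True)
    print(torch.allclose(pooled, meaned))  # True

    # Without keepdim (as EcaModule does before its 1d conv), the spatial
    # dims are dropped entirely: (N, C) instead of (N, C, 1, 1).
    print(x.mean((2, 3)).shape)  # torch.Size([2, 64])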
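
A quick way to exercise the failing configuration is to run one of the patched
modules under autocast with a channels_last input. A sketch only: the import
path is assumed from this repo's layout, and the check needs a CUDA device
with AMP available (PyTorch >= 1.6):

    import torch
    from timm.models.layers import EcaModule  # assumed export

    if torch.cuda.is_available():
        m = EcaModule(channels=64).cuda()
        x = torch.randn(8, 64, 14, 14, device='cuda').to(memory_format=torch.channels_last)
        with torch.cuda.amp.autocast():
            y = m(x)
        # Before this change, the AdaptiveAvgPool2d path could produce NaNs
        # here, per pytorch/pytorch#43992.
        assert not torch.isnan(y).any()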