From 58571e992e9f938c03ce3315834b6d1a8d0b57a3 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 5 Apr 2019 10:53:13 -0700 Subject: [PATCH] Change block avgpool in senets to mean for performance issues with NVIDIA and AMP especially --- models/senet.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/models/senet.py b/models/senet.py index 2e88d480..e66082b8 100644 --- a/models/senet.py +++ b/models/senet.py @@ -38,7 +38,7 @@ class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) + #self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d( channels, channels // reduction, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) @@ -48,7 +48,8 @@ class SEModule(nn.Module): def forward(self, x): module_input = x - x = self.avg_pool(x) + #x = self.avg_pool(x) + x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) x = self.fc1(x) x = self.relu(x) x = self.fc2(x)