diff --git a/models/senet.py b/models/senet.py index 2e88d480..e66082b8 100644 --- a/models/senet.py +++ b/models/senet.py @@ -38,7 +38,7 @@ class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) + #self.avg_pool = nn.AdaptiveAvgPool2d(1) self.fc1 = nn.Conv2d( channels, channels // reduction, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) @@ -48,7 +48,8 @@ class SEModule(nn.Module): def forward(self, x): module_input = x - x = self.avg_pool(x) + #x = self.avg_pool(x) + x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1) x = self.fc1(x) x = self.relu(x) x = self.fc2(x)