|
|
|
@ -38,7 +38,7 @@ class SEModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, reduction):
|
|
|
|
|
super(SEModule, self).__init__()
|
|
|
|
|
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
#self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
|
|
|
|
self.fc1 = nn.Conv2d(
|
|
|
|
|
channels, channels // reduction, kernel_size=1, padding=0)
|
|
|
|
|
self.relu = nn.ReLU(inplace=True)
|
|
|
|
@ -48,7 +48,8 @@ class SEModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
module_input = x
|
|
|
|
|
x = self.avg_pool(x)
|
|
|
|
|
#x = self.avg_pool(x)
|
|
|
|
|
x = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
|
|
|
|
|
x = self.fc1(x)
|
|
|
|
|
x = self.relu(x)
|
|
|
|
|
x = self.fc2(x)
|
|
|
|
|