|
|
|
@ -104,6 +104,18 @@ pretrained_config = {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _weight_init(m, n='', ll=''):
|
|
|
|
|
print(m, n, ll)
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
|
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
|
|
|
if ll and n == ll:
|
|
|
|
|
nn.init.constant_(m.weight, 0.)
|
|
|
|
|
else:
|
|
|
|
|
nn.init.constant_(m.weight, 1.)
|
|
|
|
|
nn.init.constant_(m.bias, 0.)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SEModule(nn.Module):
|
|
|
|
|
|
|
|
|
|
def __init__(self, channels, reduction):
|
|
|
|
@ -116,6 +128,9 @@ class SEModule(nn.Module):
|
|
|
|
|
channels // reduction, channels, kernel_size=1, padding=0)
|
|
|
|
|
self.sigmoid = nn.Sigmoid()
|
|
|
|
|
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
_weight_init(m)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
module_input = x
|
|
|
|
|
x = self.avg_pool(x)
|
|
|
|
@ -176,6 +191,9 @@ class SEBottleneck(Bottleneck):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
_weight_init(m, n, ll='bn3')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SEResNetBottleneck(Bottleneck):
|
|
|
|
|
"""
|
|
|
|
@ -201,6 +219,9 @@ class SEResNetBottleneck(Bottleneck):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
_weight_init(m, n, ll='bn3')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SEResNeXtBottleneck(Bottleneck):
|
|
|
|
|
"""
|
|
|
|
@ -225,6 +246,9 @@ class SEResNeXtBottleneck(Bottleneck):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
_weight_init(m, n, ll='bn3')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SEResNetBlock(nn.Module):
|
|
|
|
|
expansion = 1
|
|
|
|
@ -242,6 +266,9 @@ class SEResNetBlock(nn.Module):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
_weight_init(m, n, ll='bn2')
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
@ -378,6 +405,12 @@ class SENet(nn.Module):
|
|
|
|
|
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
|
|
|
|
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
|
|
|
|
|
|
|
|
|
for n, m in self.named_children():
|
|
|
|
|
if n == 'layer0':
|
|
|
|
|
m.apply(_weight_init)
|
|
|
|
|
else:
|
|
|
|
|
_weight_init(m)
|
|
|
|
|
|
|
|
|
|
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
|
|
|
|
downsample_kernel_size=1, downsample_padding=0):
|
|
|
|
|
downsample = None
|
|
|
|
|