Add separate zero_init_last_bn function to support more block variety without a mess

pull/87/head
Ross Wightman 5 years ago
parent 355aa152d5
commit a9d2424fd1

@ -87,6 +87,9 @@ class Bottle2neck(nn.Module):
self.relu = act_layer(inplace=True)
self.downsample = downsample
def zero_init_last_bn(self):
    """Fill the ``weight`` of this block's ``bn3`` layer with zeros, in place."""
    last_norm_scale = self.bn3.weight
    nn.init.zeros_(last_norm_scale)
def forward(self, x):
residual = x

@ -156,26 +156,38 @@ class BasicBlock(nn.Module):
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
    """Fill the ``weight`` of ``bn2`` (the last norm applied in ``forward``) with zeros."""
    last_norm_scale = self.bn2.weight
    nn.init.zeros_(last_norm_scale)
def forward(self, x):
    """Run the two-conv residual block on ``x`` and return the activated sum.

    Main path: conv1 -> bn1 -> [drop_block] -> act1 -> conv2 -> bn2 ->
    [drop_block] -> [se] -> [drop_path]. Bracketed stages are skipped when
    the corresponding attribute is ``None``.

    The shortcut is the block input, passed through ``self.downsample``
    when one is configured. The shortcut is added to the main path and the
    sum goes through ``act2``.
    """
    # Keep the untouched input for the shortcut branch before `x` is rebound.
    residual = x

    x = self.conv1(x)
    x = self.bn1(x)
    if self.drop_block is not None:
        x = self.drop_block(x)
    x = self.act1(x)

    x = self.conv2(x)
    x = self.bn2(x)
    if self.drop_block is not None:
        x = self.drop_block(x)

    if self.se is not None:
        x = self.se(x)

    if self.drop_path is not None:
        x = self.drop_path(x)

    # Downsample the saved input exactly once; never feed it the main-path
    # output, which by now has been through both convs.
    if self.downsample is not None:
        residual = self.downsample(residual)
    x += residual
    x = self.act2(x)

    return x
class Bottleneck(nn.Module):
@ -207,31 +219,44 @@ class Bottleneck(nn.Module):
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
    """Fill the ``weight`` of ``bn3`` (the last norm applied in ``forward``) with zeros."""
    last_norm_scale = self.bn3.weight
    nn.init.zeros_(last_norm_scale)
def forward(self, x):
    """Run the three-conv residual block on ``x`` and return the activated sum.

    Main path: conv1 -> bn1 -> [drop_block] -> act1 -> conv2 -> bn2 ->
    [drop_block] -> act2 -> conv3 -> bn3 -> [drop_block] -> [se] ->
    [drop_path]. Bracketed stages are skipped when the corresponding
    attribute is ``None``.

    The shortcut is the block input, passed through ``self.downsample``
    when one is configured. The shortcut is added to the main path and the
    sum goes through ``act3``.
    """
    # Keep the untouched input for the shortcut branch before `x` is rebound.
    residual = x

    x = self.conv1(x)
    x = self.bn1(x)
    if self.drop_block is not None:
        x = self.drop_block(x)
    x = self.act1(x)

    x = self.conv2(x)
    x = self.bn2(x)
    if self.drop_block is not None:
        x = self.drop_block(x)
    x = self.act2(x)

    x = self.conv3(x)
    x = self.bn3(x)
    if self.drop_block is not None:
        x = self.drop_block(x)

    if self.se is not None:
        x = self.se(x)

    if self.drop_path is not None:
        x = self.drop_path(x)

    # Downsample the saved input exactly once; never feed it the main-path
    # output, which by now has been through all three convs.
    if self.downsample is not None:
        residual = self.downsample(residual)
    x += residual
    x = self.act3(x)

    return x
class ResNet(nn.Module):
@ -367,17 +392,16 @@ class ResNet(nn.Module):
self.num_features = 512 * block.expansion
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
for n, m in self.named_modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
if zero_init_last_bn and 'layer' in n and last_bn_name in n:
# Initialize weight/gamma of last BN in each residual block to zero
nn.init.constant_(m.weight, 0.)
else:
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.weight, 1.)
nn.init.constant_(m.bias, 0.)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
use_se=False, avg_down=False, down_kernel_size=1, **kwargs):

@ -63,6 +63,9 @@ class SelectiveKernelBasic(nn.Module):
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
    """Fill the ``weight`` of the ``bn`` layer nested in ``conv2`` with zeros, in place."""
    last_norm_scale = self.conv2.bn.weight
    nn.init.zeros_(last_norm_scale)
def forward(self, x):
residual = x
x = self.conv1(x)
@ -109,6 +112,9 @@ class SelectiveKernelBottleneck(nn.Module):
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
    """Fill the ``weight`` of the ``bn`` layer nested in ``conv3`` with zeros, in place."""
    last_norm_scale = self.conv3.bn.weight
    nn.init.zeros_(last_norm_scale)
def forward(self, x):
residual = x
x = self.conv1(x)

Loading…
Cancel
Save