|
|
|
@ -156,26 +156,38 @@ class BasicBlock(nn.Module):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.dilation = dilation
|
|
|
|
|
self.drop_block = drop_block
|
|
|
|
|
self.drop_path = drop_path
|
|
|
|
|
|
|
|
|
|
def zero_init_last_bn(self):
|
|
|
|
|
nn.init.zeros_(self.bn2.weight)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
|
out = self.conv1(x)
|
|
|
|
|
out = self.bn1(out)
|
|
|
|
|
out = self.act1(out)
|
|
|
|
|
out = self.conv2(out)
|
|
|
|
|
out = self.bn2(out)
|
|
|
|
|
x = self.conv1(x)
|
|
|
|
|
x = self.bn1(x)
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
x = self.act1(x)
|
|
|
|
|
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
x = self.bn2(x)
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
|
|
|
|
|
if self.se is not None:
|
|
|
|
|
out = self.se(out)
|
|
|
|
|
x = self.se(x)
|
|
|
|
|
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(x)
|
|
|
|
|
if self.drop_path is not None:
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
|
|
|
|
|
out += residual
|
|
|
|
|
out = self.act2(out)
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(residual)
|
|
|
|
|
x += residual
|
|
|
|
|
x = self.act2(x)
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Bottleneck(nn.Module):
|
|
|
|
@ -207,31 +219,44 @@ class Bottleneck(nn.Module):
|
|
|
|
|
self.downsample = downsample
|
|
|
|
|
self.stride = stride
|
|
|
|
|
self.dilation = dilation
|
|
|
|
|
self.drop_block = drop_block
|
|
|
|
|
self.drop_path = drop_path
|
|
|
|
|
|
|
|
|
|
def zero_init_last_bn(self):
|
|
|
|
|
nn.init.zeros_(self.bn3.weight)
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
residual = x
|
|
|
|
|
|
|
|
|
|
out = self.conv1(x)
|
|
|
|
|
out = self.bn1(out)
|
|
|
|
|
out = self.act1(out)
|
|
|
|
|
x = self.conv1(x)
|
|
|
|
|
x = self.bn1(x)
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
x = self.act1(x)
|
|
|
|
|
|
|
|
|
|
out = self.conv2(out)
|
|
|
|
|
out = self.bn2(out)
|
|
|
|
|
out = self.act2(out)
|
|
|
|
|
x = self.conv2(x)
|
|
|
|
|
x = self.bn2(x)
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
x = self.act2(x)
|
|
|
|
|
|
|
|
|
|
out = self.conv3(out)
|
|
|
|
|
out = self.bn3(out)
|
|
|
|
|
x = self.conv3(x)
|
|
|
|
|
x = self.bn3(x)
|
|
|
|
|
if self.drop_block is not None:
|
|
|
|
|
x = self.drop_block(x)
|
|
|
|
|
|
|
|
|
|
if self.se is not None:
|
|
|
|
|
out = self.se(out)
|
|
|
|
|
x = self.se(x)
|
|
|
|
|
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(x)
|
|
|
|
|
if self.drop_path is not None:
|
|
|
|
|
x = self.drop_path(x)
|
|
|
|
|
|
|
|
|
|
out += residual
|
|
|
|
|
out = self.act3(out)
|
|
|
|
|
if self.downsample is not None:
|
|
|
|
|
residual = self.downsample(residual)
|
|
|
|
|
x += residual
|
|
|
|
|
x = self.act3(x)
|
|
|
|
|
|
|
|
|
|
return out
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet(nn.Module):
|
|
|
|
@ -367,17 +392,16 @@ class ResNet(nn.Module):
|
|
|
|
|
self.num_features = 512 * block.expansion
|
|
|
|
|
self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
|
|
|
|
|
|
|
|
|
|
last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2'
|
|
|
|
|
for n, m in self.named_modules():
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
|
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
|
|
|
if zero_init_last_bn and 'layer' in n and last_bn_name in n:
|
|
|
|
|
# Initialize weight/gamma of last BN in each residual block to zero
|
|
|
|
|
nn.init.constant_(m.weight, 0.)
|
|
|
|
|
else:
|
|
|
|
|
nn.init.constant_(m.weight, 1.)
|
|
|
|
|
nn.init.constant_(m.bias, 0.)
|
|
|
|
|
if zero_init_last_bn:
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
if hasattr(m, 'zero_init_last_bn'):
|
|
|
|
|
m.zero_init_last_bn()
|
|
|
|
|
|
|
|
|
|
def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
|
|
|
|
|
use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
|
|
|
|
|