From a9d2424fd1680590146bcd4eed912cc84bbe6a5e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 30 Jan 2020 16:51:49 -0800 Subject: [PATCH] Add separate zero_init_last_bn function to support more block variety without a mess --- timm/models/res2net.py | 3 ++ timm/models/resnet.py | 86 +++++++++++++++++++++++++++--------------- timm/models/sknet.py | 6 +++ 3 files changed, 64 insertions(+), 31 deletions(-) diff --git a/timm/models/res2net.py b/timm/models/res2net.py index c83aba62..bcb7eaaf 100644 --- a/timm/models/res2net.py +++ b/timm/models/res2net.py @@ -87,6 +87,9 @@ class Bottle2neck(nn.Module): self.relu = act_layer(inplace=True) self.downsample = downsample + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) + def forward(self, x): residual = x diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 3e0ce23e..d97a6aad 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -156,26 +156,38 @@ class BasicBlock(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn2.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) - out = self.conv2(out) - out = self.bn2(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act2(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act2(x) - return out + return x class Bottleneck(nn.Module): @@ -207,31 +219,44 @@ class Bottleneck(nn.Module): self.downsample = downsample self.stride = stride self.dilation = dilation + self.drop_block = drop_block + self.drop_path = drop_path + + def zero_init_last_bn(self): + nn.init.zeros_(self.bn3.weight) def forward(self, x): residual = x - out = self.conv1(x) - out = self.bn1(out) - out = self.act1(out) + x = self.conv1(x) + x = self.bn1(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act1(x) - out = self.conv2(out) - out = self.bn2(out) - out = self.act2(out) + x = self.conv2(x) + x = self.bn2(x) + if self.drop_block is not None: + x = self.drop_block(x) + x = self.act2(x) - out = self.conv3(out) - out = self.bn3(out) + x = self.conv3(x) + x = self.bn3(x) + if self.drop_block is not None: + x = self.drop_block(x) if self.se is not None: - out = self.se(out) + x = self.se(x) - if self.downsample is not None: - residual = self.downsample(x) + if self.drop_path is not None: + x = self.drop_path(x) - out += residual - out = self.act3(out) + if self.downsample is not None: + residual = self.downsample(residual) + x += residual + x = self.act3(x) - return out + return x class ResNet(nn.Module): @@ -367,17 +392,16 @@ class ResNet(nn.Module): self.num_features = 512 * block.expansion self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes) - last_bn_name = 'bn3' if 'Bottle' in block.__name__ else 'bn2' for n, m in self.named_modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') elif isinstance(m, nn.BatchNorm2d): - if zero_init_last_bn and 'layer' in n and last_bn_name in n: - # Initialize weight/gamma of last BN in each residual block to zero - nn.init.constant_(m.weight, 0.) - else: - nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.weight, 1.) nn.init.constant_(m.bias, 0.) + if zero_init_last_bn: + for m in self.modules(): + if hasattr(m, 'zero_init_last_bn'): + m.zero_init_last_bn() def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1, use_se=False, avg_down=False, down_kernel_size=1, **kwargs): diff --git a/timm/models/sknet.py b/timm/models/sknet.py index 41e19075..7cf4fbe6 100644 --- a/timm/models/sknet.py +++ b/timm/models/sknet.py @@ -63,6 +63,9 @@ class SelectiveKernelBasic(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv2.bn.weight) + def forward(self, x): residual = x x = self.conv1(x) @@ -109,6 +112,9 @@ class SelectiveKernelBottleneck(nn.Module): self.drop_block = drop_block self.drop_path = drop_path + def zero_init_last_bn(self): + nn.init.zeros_(self.conv3.bn.weight) + def forward(self, x): residual = x x = self.conv1(x)