From afb357ff68f8efafb7985bf031047e7e42876dc9 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Mon, 22 Apr 2019 17:46:17 -0700 Subject: [PATCH] Make genmobilenet weight init switchable, fix fan_out in google style linear init --- models/genmobilenet.py | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) diff --git a/models/genmobilenet.py b/models/genmobilenet.py index 46647fd8..842ba33f 100644 --- a/models/genmobilenet.py +++ b/models/genmobilenet.py @@ -289,9 +289,11 @@ class _BlockBuilder: return blocks -def _initialize_weight(m): +def _initialize_weight_goog(m): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py if isinstance(m, nn.Conv2d): - n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out m.weight.data.normal_(0, math.sqrt(2.0 / n)) if m.bias is not None: m.bias.data.zero_() @@ -299,12 +301,22 @@ def _initialize_weight(m): m.weight.data.fill_(1.0) m.bias.data.zero_() elif isinstance(m, nn.Linear): - n = m.weight.size(1) + n = m.weight.size(0) # fan-out init_range = 1.0 / math.sqrt(n) m.weight.data.uniform_(-init_range, init_range) m.bias.data.zero_() +def _initialize_weight_default(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') + + class DepthwiseSeparableConv(nn.Module): def __init__(self, in_chs, out_chs, kernel_size, stride=1, act_fn=F.relu, noskip=False, pw_act=False, @@ -485,7 +497,8 @@ class GenMobileNet(nn.Module): def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280, depth_multiplier=1.0, depth_divisor=8, min_depth=None, bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT, - drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False): + drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False, + weight_init='goog'): super(GenMobileNet, self).__init__() self.num_classes = num_classes self.depth_multiplier = depth_multiplier @@ -515,7 +528,10 @@ class GenMobileNet(nn.Module): self.classifier = nn.Linear(self.num_features, self.num_classes) for m in self.modules(): - _initialize_weight(m) + if weight_init == 'goog': + _initialize_weight_goog(m) + else: + _initialize_weight_default(m) def get_classifier(self): return self.classifier