Make genmobilenet weight init switchable, fix fan_out in google style linear init

pull/1/head
Ross Wightman 6 years ago
parent 0a853990e7
commit afb357ff68

@@ -289,9 +289,11 @@ class _BlockBuilder:
         return blocks
 
 
-def _initialize_weight(m):
+def _initialize_weight_goog(m):
+    # weight init as per Tensorflow Official impl
+    # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
     if isinstance(m, nn.Conv2d):
-        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+        n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels  # fan-out
         m.weight.data.normal_(0, math.sqrt(2.0 / n))
         if m.bias is not None:
             m.bias.data.zero_()
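Side note: the Google-style conv init above is the He/Kaiming normal init computed by hand over fan-out; for a plain Conv2d (groups=1) it matches PyTorch's built-in initializer. A minimal sketch with illustrative shapes:

    import math
    import torch.nn as nn

    conv = nn.Conv2d(32, 64, kernel_size=3)
    n = conv.kernel_size[0] * conv.kernel_size[1] * conv.out_channels  # fan-out = 3 * 3 * 64
    conv.weight.data.normal_(0, math.sqrt(2.0 / n))
    # same distribution via the built-in initializer:
    # nn.init.kaiming_normal_(conv.weight, mode='fan_out', nonlinearity='relu')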
@@ -299,12 +301,22 @@ def _initialize_weight(m):
         m.weight.data.fill_(1.0)
         m.bias.data.zero_()
     elif isinstance(m, nn.Linear):
-        n = m.weight.size(1)
+        n = m.weight.size(0)  # fan-out
         init_range = 1.0 / math.sqrt(n)
         m.weight.data.uniform_(-init_range, init_range)
         m.bias.data.zero_()
 
 
+def _initialize_weight_default(m):
+    if isinstance(m, nn.Conv2d):
+        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+    elif isinstance(m, nn.BatchNorm2d):
+        m.weight.data.fill_(1.0)
+        m.bias.data.zero_()
+    elif isinstance(m, nn.Linear):
+        nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
+
+
 class DepthwiseSeparableConv(nn.Module):
     def __init__(self, in_chs, out_chs, kernel_size,
                  stride=1, act_fn=F.relu, noskip=False, pw_act=False,
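The size(1) -> size(0) change above is the fan-out fix from the commit title: nn.Linear stores its weight as (out_features, in_features), so size(0) is fan-out while size(1) is fan-in. A quick shape check (dimensions illustrative):

    import torch.nn as nn

    fc = nn.Linear(in_features=1280, out_features=1000)
    print(fc.weight.shape)       # torch.Size([1000, 1280])
    fan_out = fc.weight.size(0)  # 1000, used for the init range after this commit
    fan_in = fc.weight.size(1)   # 1280, what the old code used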
@@ -485,7 +497,8 @@ class GenMobileNet(nn.Module):
     def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
                  depth_multiplier=1.0, depth_divisor=8, min_depth=None,
                  bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
-                 drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False):
+                 drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
+                 weight_init='goog'):
         super(GenMobileNet, self).__init__()
         self.num_classes = num_classes
         self.depth_multiplier = depth_multiplier
@@ -515,7 +528,10 @@ class GenMobileNet(nn.Module):
         self.classifier = nn.Linear(self.num_features, self.num_classes)
 
         for m in self.modules():
-            _initialize_weight(m)
+            if weight_init == 'goog':
+                _initialize_weight_goog(m)
+            else:
+                _initialize_weight_default(m)
 
     def get_classifier(self):
         return self.classifier
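The new weight_init argument selects the init scheme at construction time; a hypothetical usage sketch (block_args comes from the architecture decoder elsewhere in this file and is elided here):

    # 'goog' (the default) reproduces the TF reference init;
    # any other value falls through to _initialize_weight_default.
    model_tf = GenMobileNet(block_args, num_classes=1000, weight_init='goog')
    model_pt = GenMobileNet(block_args, num_classes=1000, weight_init='default')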
