|
|
|
@ -289,9 +289,11 @@ class _BlockBuilder:
|
|
|
|
|
return blocks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _initialize_weight(m):
|
|
|
|
|
def _initialize_weight_goog(m):
|
|
|
|
|
# weight init as per Tensorflow Official impl
|
|
|
|
|
# https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
|
|
|
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels # fan-out
|
|
|
|
|
m.weight.data.normal_(0, math.sqrt(2.0 / n))
|
|
|
|
|
if m.bias is not None:
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
@ -299,12 +301,22 @@ def _initialize_weight(m):
|
|
|
|
|
m.weight.data.fill_(1.0)
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
elif isinstance(m, nn.Linear):
|
|
|
|
|
n = m.weight.size(1)
|
|
|
|
|
n = m.weight.size(0) # fan-out
|
|
|
|
|
init_range = 1.0 / math.sqrt(n)
|
|
|
|
|
m.weight.data.uniform_(-init_range, init_range)
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _initialize_weight_default(m):
|
|
|
|
|
if isinstance(m, nn.Conv2d):
|
|
|
|
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
|
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
|
|
|
m.weight.data.fill_(1.0)
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
elif isinstance(m, nn.Linear):
|
|
|
|
|
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DepthwiseSeparableConv(nn.Module):
|
|
|
|
|
def __init__(self, in_chs, out_chs, kernel_size,
|
|
|
|
|
stride=1, act_fn=F.relu, noskip=False, pw_act=False,
|
|
|
|
@ -485,7 +497,8 @@ class GenMobileNet(nn.Module):
|
|
|
|
|
def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=32, num_features=1280,
|
|
|
|
|
depth_multiplier=1.0, depth_divisor=8, min_depth=None,
|
|
|
|
|
bn_momentum=_BN_MOMENTUM_PT_DEFAULT, bn_eps=_BN_EPS_PT_DEFAULT,
|
|
|
|
|
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False):
|
|
|
|
|
drop_rate=0., act_fn=F.relu, global_pool='avg', skip_head_conv=False,
|
|
|
|
|
weight_init='goog'):
|
|
|
|
|
super(GenMobileNet, self).__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.depth_multiplier = depth_multiplier
|
|
|
|
@ -515,7 +528,10 @@ class GenMobileNet(nn.Module):
|
|
|
|
|
self.classifier = nn.Linear(self.num_features, self.num_classes)
|
|
|
|
|
|
|
|
|
|
for m in self.modules():
|
|
|
|
|
_initialize_weight(m)
|
|
|
|
|
if weight_init == 'goog':
|
|
|
|
|
_initialize_weight_goog(m)
|
|
|
|
|
else:
|
|
|
|
|
_initialize_weight_default(m)
|
|
|
|
|
|
|
|
|
|
def get_classifier(self):
|
|
|
|
|
return self.classifier
|
|
|
|
|