|
|
|
@ -358,15 +358,24 @@ class EfficientNetBuilder:
|
|
|
|
|
return stages
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weight_goog(m, n=''):
|
|
|
|
|
def _init_weight_goog(m, n='', fix_group_fanout=False):
|
|
|
|
|
""" Weight initialization as per Tensorflow official implementations.
|
|
|
|
|
|
|
|
|
|
Args:
|
|
|
|
|
m (nn.Module): module to init
|
|
|
|
|
n (str): module name
|
|
|
|
|
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
|
|
|
|
|
|
|
|
|
|
FIXME change fix_group_fanout to default to True if experiments show better training results
|
|
|
|
|
|
|
|
|
|
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
|
|
|
|
|
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
|
|
|
|
|
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
|
|
|
|
|
"""
|
|
|
|
|
if isinstance(m, CondConv2d):
|
|
|
|
|
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
|
|
|
if fix_group_fanout:
|
|
|
|
|
fan_out //= m.groups
|
|
|
|
|
init_weight_fn = get_condconv_initializer(
|
|
|
|
|
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
|
|
|
|
|
init_weight_fn(m.weight)
|
|
|
|
@ -374,6 +383,8 @@ def _init_weight_goog(m, n=''):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
elif isinstance(m, nn.Conv2d):
|
|
|
|
|
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
|
|
|
|
if fix_group_fanout:
|
|
|
|
|
fan_out //= m.groups
|
|
|
|
|
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
|
|
|
|
if m.bias is not None:
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _init_weight_default(m, n=''):
|
|
|
|
|
""" Basic ResNet (Kaiming) style weight init"""
|
|
|
|
|
if isinstance(m, CondConv2d):
|
|
|
|
|
init_fn = get_condconv_initializer(partial(
|
|
|
|
|
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
|
|
|
|
|
init_fn(m.weight)
|
|
|
|
|
elif isinstance(m, nn.Conv2d):
|
|
|
|
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
|
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
|
|
|
m.weight.data.fill_(1.0)
|
|
|
|
|
m.bias.data.zero_()
|
|
|
|
|
elif isinstance(m, nn.Linear):
|
|
|
|
|
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def efficientnet_init_weights(model: nn.Module, init_fn=None):
|
|
|
|
|
init_fn = init_fn or _init_weight_goog
|
|
|
|
|
for n, m in model.named_modules():
|
|
|
|
|