Remove unused default_init for EfficientNets, experimenting with fanout calc for #84

pull/95/head
Ross Wightman 5 years ago
parent cade829105
commit d0eb59ef46

@ -358,15 +358,24 @@ class EfficientNetBuilder:
return stages return stages
def _init_weight_goog(m, n=''): def _init_weight_goog(m, n='', fix_group_fanout=False):
""" Weight initialization as per Tensorflow official implementations. """ Weight initialization as per Tensorflow official implementations.
Args:
m (nn.Module): module to init
n (str): module name
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
FIXME change fix_group_fanout to default to True if experiments show better training results
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
""" """
if isinstance(m, CondConv2d): if isinstance(m, CondConv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
init_weight_fn = get_condconv_initializer( init_weight_fn = get_condconv_initializer(
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
init_weight_fn(m.weight) init_weight_fn(m.weight)
@ -374,6 +383,8 @@ def _init_weight_goog(m, n=''):
m.bias.data.zero_() m.bias.data.zero_()
elif isinstance(m, nn.Conv2d): elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if fix_group_fanout:
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None: if m.bias is not None:
m.bias.data.zero_() m.bias.data.zero_()
@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
m.bias.data.zero_() m.bias.data.zero_()
def _init_weight_default(m, n=''):
""" Basic ResNet (Kaiming) style weight init"""
if isinstance(m, CondConv2d):
init_fn = get_condconv_initializer(partial(
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
init_fn(m.weight)
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1.0)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
def efficientnet_init_weights(model: nn.Module, init_fn=None): def efficientnet_init_weights(model: nn.Module, init_fn=None):
init_fn = init_fn or _init_weight_goog init_fn = init_fn or _init_weight_goog
for n, m in model.named_modules(): for n, m in model.named_modules():

Loading…
Cancel
Save