diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index db6f54f9..ca2060c4 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -358,15 +358,24 @@ class EfficientNetBuilder: return stages -def _init_weight_goog(m, n=''): +def _init_weight_goog(m, n='', fix_group_fanout=False): """ Weight initialization as per Tensorflow official implementations. + Args: + m (nn.Module): module to init + n (str): module name + fix_group_fanout (bool): enable correct fanout calculation w/ group convs + + FIXME change fix_group_fanout to default to True if experiments show better training results + Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py """ if isinstance(m, CondConv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups init_weight_fn = get_condconv_initializer( lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) init_weight_fn(m.weight) @@ -374,6 +383,8 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() elif isinstance(m, nn.Conv2d): fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) if m.bias is not None: m.bias.data.zero_() @@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''): m.bias.data.zero_() -def _init_weight_default(m, n=''): - """ Basic ResNet (Kaiming) style weight init""" - if isinstance(m, CondConv2d): - init_fn = get_condconv_initializer(partial( - nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) - init_fn(m.weight) - elif isinstance(m, nn.Conv2d): - nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(m, nn.BatchNorm2d): - m.weight.data.fill_(1.0) - m.bias.data.zero_() - elif isinstance(m, nn.Linear): - nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') - - def efficientnet_init_weights(model: nn.Module, init_fn=None): init_fn = init_fn or _init_weight_goog for n, m in model.named_modules():