diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index 137705de..f8f0df8a 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -359,15 +359,13 @@ class EfficientNetBuilder: return stages -def _init_weight_goog(m, n='', fix_group_fanout=False): +def _init_weight_goog(m, n='', fix_group_fanout=True): """ Weight initialization as per Tensorflow official implementations. Args: m (nn.Module): module to init n (str): module name - fix_group_fanout (bool): enable correct fanout calculation w/ group convs - - FIXME change fix_group_fanout to default to True if experiments show better training results + fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py