Enable fixed fanout calc in EfficientNet/MobileNetV3 weight init by default. Fix #84

pull/94/head
Ross Wightman 5 years ago
parent 27b3680d49
commit 9fee316752

@ -359,15 +359,13 @@ class EfficientNetBuilder:
return stages return stages
def _init_weight_goog(m, n='', fix_group_fanout=False): def _init_weight_goog(m, n='', fix_group_fanout=True):
""" Weight initialization as per Tensorflow official implementations. """ Weight initialization as per Tensorflow official implementations.
Args: Args:
m (nn.Module): module to init m (nn.Module): module to init
n (str): module name n (str): module name
fix_group_fanout (bool): enable correct fanout calculation w/ group convs fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
FIXME change fix_group_fanout to default to True if experiments show better training results
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc: Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py

Loading…
Cancel
Save