diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 52c5c81c..42435b0d 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -322,6 +322,7 @@ class EfficientNet(nn.Module): # Stem if not fix_stem: stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + print(stem_size) self.conv_stem = create_conv2d(self._in_chs, stem_size, 3, stride=2, padding=pad_type) self.bn1 = norm_layer(stem_size, **norm_kwargs) self.act1 = act_layer(inplace=True) @@ -569,7 +570,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar return model -def _gen_mobilenet_v2(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): """ Generate MobileNet-V2 network Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py Paper: https://arxiv.org/abs/1801.04381 @@ -584,8 +586,10 @@ def _gen_mobilenet_v2(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre ['ir_r1_k3_s1_e6_c320'], ] model_kwargs = dict( - block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier), + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), stem_size=32, + fix_stem=fix_stem_head, channel_multiplier=channel_multiplier, norm_kwargs=resolve_bn_args(kwargs), act_layer=nn.ReLU6, @@ -955,23 +959,25 @@ def mobilenetv2_100(pretrained=False, **kwargs): @register_model -def mobilenetv2_100d(pretrained=False, **kwargs): +def mobilenetv2_140(pretrained=False, **kwargs): """ MobileNet V2 """ - model = _gen_mobilenet_v2('mobilenetv2_100d', 1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) return model @register_model def mobilenetv2_110d(pretrained=False, **kwargs): """ MobileNet V2 """ - model = _gen_mobilenet_v2('mobilenetv2_110d', 1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + model = _gen_mobilenet_v2( + 'mobilenetv2_100d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) return model @register_model -def mobilenetv2_140(pretrained=False, **kwargs): +def mobilenetv2_120d(pretrained=False, **kwargs): """ MobileNet V2 """ - model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) return model