diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index 7eeda3ca..4a89590b 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -453,19 +453,19 @@ class EfficientNetFeatures(nn.Module): def _create_effnet(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = EfficientNet if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_cls = EfficientNetFeatures - else: - load_strict = True - model_cls = EfficientNet model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index d246812e..1c0bc9f0 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -773,16 +773,16 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet - strict = True + features_only = False if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures model_kwargs['num_classes'] = 0 - strict = False - + features_only = True model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model diff --git a/timm/models/mobilenetv3.py b/timm/models/mobilenetv3.py index afded75f..8a48ce72 100644 --- a/timm/models/mobilenetv3.py +++ b/timm/models/mobilenetv3.py @@ -201,20 +201,20 @@ class MobileNetV3Features(nn.Module): def _create_mnv3(model_kwargs, variant, pretrained=False): + features_only = False + model_cls = MobileNetV3 if model_kwargs.pop('features_only', False): - load_strict = False + features_only = True model_kwargs.pop('num_classes', 0) model_kwargs.pop('num_features', 0) model_kwargs.pop('head_conv', None) model_kwargs.pop('head_bias', None) model_cls = MobileNetV3Features - else: - load_strict = True - model_cls = MobileNetV3 model = build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - pretrained_strict=load_strict, **model_kwargs) - model.default_cfg = default_cfg_for_features(model.default_cfg) + pretrained_strict=not features_only, **model_kwargs) + if features_only: + model.default_cfg = default_cfg_for_features(model.default_cfg) return model