Bug in last mod for features_only default_cfg

pull/297/head
Ross Wightman 4 years ago
parent 867a0e5a04
commit cd72e66eff

@ -453,19 +453,19 @@ class EfficientNetFeatures(nn.Module):
def _create_effnet(model_kwargs, variant, pretrained=False):
features_only = False
model_cls = EfficientNet
if model_kwargs.pop('features_only', False):
load_strict = False
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_cls = EfficientNetFeatures
else:
load_strict = True
model_cls = EfficientNet
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **model_kwargs)
model.default_cfg = default_cfg_for_features(model.default_cfg)
pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model

@ -773,16 +773,16 @@ class HighResolutionNetFeatures(HighResolutionNet):
def _create_hrnet(variant, pretrained, **model_kwargs):
model_cls = HighResolutionNet
strict = True
features_only = False
if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures
model_kwargs['num_classes'] = 0
strict = False
features_only = True
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
model.default_cfg = default_cfg_for_features(model.default_cfg)
model_cfg=cfg_cls[variant], pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model

@ -201,20 +201,20 @@ class MobileNetV3Features(nn.Module):
def _create_mnv3(model_kwargs, variant, pretrained=False):
features_only = False
model_cls = MobileNetV3
if model_kwargs.pop('features_only', False):
load_strict = False
features_only = True
model_kwargs.pop('num_classes', 0)
model_kwargs.pop('num_features', 0)
model_kwargs.pop('head_conv', None)
model_kwargs.pop('head_bias', None)
model_cls = MobileNetV3Features
else:
load_strict = True
model_cls = MobileNetV3
model = build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
pretrained_strict=load_strict, **model_kwargs)
model.default_cfg = default_cfg_for_features(model.default_cfg)
pretrained_strict=not features_only, **model_kwargs)
if features_only:
model.default_cfg = default_cfg_for_features(model.default_cfg)
return model

Loading…
Cancel
Save