From 90a01f47d19838f52142f00805d56b8e94a6ea14 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 1 Sep 2020 17:37:55 -0700 Subject: [PATCH] hrnet features_only pretrained weight loading issue. Fix #232. --- tests/test_models.py | 6 ++++++ timm/models/hrnet.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/test_models.py b/tests/test_models.py index 71d643dd..d6fcaf79 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -120,6 +120,12 @@ if 'GITHUB_ACTIONS' not in os.environ: in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change create_model(model_name, pretrained=True, in_chans=in_chans) + @pytest.mark.timeout(120) + @pytest.mark.parametrize('model_name', list_models(pretrained=True)) + @pytest.mark.parametrize('batch_size', [1]) + def test_model_features_pretrained(model_name, batch_size): + """Create that pretrained weights load when features_only==True.""" + create_model(model_name, pretrained=True, features_only=True) EXCLUDE_JIT_FILTERS = [ '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index ad865887..1e867686 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -773,12 +773,14 @@ class HighResolutionNetFeatures(HighResolutionNet): def _create_hrnet(variant, pretrained, **model_kwargs): model_cls = HighResolutionNet + strict = True if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures + strict = False return build_model_with_cfg( model_cls, variant, pretrained, default_cfg=default_cfgs[variant], - model_cfg=cfg_cls[variant], **model_kwargs) + model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs) @register_model