From da6cd2cc1fd8696986b1cf224a464f45819eec2d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 29 Oct 2020 15:43:39 -0700 Subject: [PATCH] Fix regression for pretrained classifier loading when using entrypt functions directly --- tests/test_models.py | 2 +- timm/models/helpers.py | 5 +++-- timm/models/hrnet.py | 1 + 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index c673dc96..db8efbf3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ MAX_FWD_FEAT_SIZE = 448 @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/helpers.py b/timm/models/helpers.py index ac119295..b90ce1db 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -277,11 +277,12 @@ def build_model_with_cfg( if pruned: model = adapt_model_from_file(model, variant) + # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: load_pretrained( model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), + num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 1e867686..2e8757b5 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs): strict = True if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures + model_kwargs['num_classes'] = 0 strict = False return build_model_with_cfg(