diff --git a/tests/test_models.py b/tests/test_models.py index c673dc96..db8efbf3 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -24,7 +24,7 @@ MAX_FWD_FEAT_SIZE = 448 @pytest.mark.timeout(120) -@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) +@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1])) @pytest.mark.parametrize('batch_size', [1]) def test_model_forward(model_name, batch_size): """Run a single forward pass with each model""" diff --git a/timm/models/helpers.py b/timm/models/helpers.py index ac119295..b90ce1db 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -277,11 +277,12 @@ def build_model_with_cfg( if pruned: model = adapt_model_from_file(model, variant) + # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats + num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000)) if pretrained: load_pretrained( model, - num_classes=kwargs.get('num_classes', 0), - in_chans=kwargs.get('in_chans', 3), + num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3), filter_fn=pretrained_filter_fn, strict=pretrained_strict) if features: diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 1e867686..2e8757b5 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs): strict = True if model_kwargs.pop('features_only', False): model_cls = HighResolutionNetFeatures + model_kwargs['num_classes'] = 0 strict = False return build_model_with_cfg(