Fix regression for pretrained classifier loading when using entrypt functions directly

pull/263/head
Ross Wightman 4 years ago
parent f591e90b0d
commit da6cd2cc1f

@ -24,7 +24,7 @@ MAX_FWD_FEAT_SIZE = 448
@pytest.mark.timeout(120) @pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS)) @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS[:-1]))
@pytest.mark.parametrize('batch_size', [1]) @pytest.mark.parametrize('batch_size', [1])
def test_model_forward(model_name, batch_size): def test_model_forward(model_name, batch_size):
"""Run a single forward pass with each model""" """Run a single forward pass with each model"""

@ -277,11 +277,12 @@ def build_model_with_cfg(
if pruned: if pruned:
model = adapt_model_from_file(model, variant) model = adapt_model_from_file(model, variant)
# for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
if pretrained: if pretrained:
load_pretrained( load_pretrained(
model, model,
num_classes=kwargs.get('num_classes', 0), num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
in_chans=kwargs.get('in_chans', 3),
filter_fn=pretrained_filter_fn, strict=pretrained_strict) filter_fn=pretrained_filter_fn, strict=pretrained_strict)
if features: if features:

@ -776,6 +776,7 @@ def _create_hrnet(variant, pretrained, **model_kwargs):
strict = True strict = True
if model_kwargs.pop('features_only', False): if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures model_cls = HighResolutionNetFeatures
model_kwargs['num_classes'] = 0
strict = False strict = False
return build_model_with_cfg( return build_model_with_cfg(

Loading…
Cancel
Save