hrnet features_only pretrained weight loading issue. Fix .

pull/233/head
Ross Wightman 4 years ago
parent 110a7c4982
commit 90a01f47d1

@ -120,6 +120,12 @@ if 'GITHUB_ACTIONS' not in os.environ:
in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change in_chans = 3 if 'pruned' in model_name else 1 # pruning not currently supported with in_chans change
create_model(model_name, pretrained=True, in_chans=in_chans) create_model(model_name, pretrained=True, in_chans=in_chans)
@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', list_models(pretrained=True))
@pytest.mark.parametrize('batch_size', [1])
def test_model_features_pretrained(model_name, batch_size):
"""Create that pretrained weights load when features_only==True."""
create_model(model_name, pretrained=True, features_only=True)
EXCLUDE_JIT_FILTERS = [ EXCLUDE_JIT_FILTERS = [
'*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable '*iabn*', 'tresnet*', # models using inplace abn unlikely to ever be scriptable

@ -773,12 +773,14 @@ class HighResolutionNetFeatures(HighResolutionNet):
def _create_hrnet(variant, pretrained, **model_kwargs): def _create_hrnet(variant, pretrained, **model_kwargs):
model_cls = HighResolutionNet model_cls = HighResolutionNet
strict = True
if model_kwargs.pop('features_only', False): if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures model_cls = HighResolutionNetFeatures
strict = False
return build_model_with_cfg( return build_model_with_cfg(
model_cls, variant, pretrained, default_cfg=default_cfgs[variant], model_cls, variant, pretrained, default_cfg=default_cfgs[variant],
model_cfg=cfg_cls[variant], **model_kwargs) model_cfg=cfg_cls[variant], pretrained_strict=strict, **model_kwargs)
@register_model @register_model

Loading…
Cancel
Save