diff --git a/timm/models/factory.py b/timm/models/factory.py
index fa23150c..aac29181 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -55,7 +55,7 @@ def create_model(
         try:
             model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
             create_fn = model_entrypoint(model_cfg.pop("architecture"))
-            model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
+            model = create_fn(**model_args, hf_model_id=model_name, **kwargs)
             # Probably need some extra stuff, but this is a PoC of how the config in the model hub
             # could overwrite the default config values.
             # model.default_cfg.update(model_cfg)
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 62558913..f4e291b0 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -184,19 +184,19 @@ def adapt_input_conv(in_chans, conv_weight):
 
 
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
-                    progress=False, hf_checkpoint=None, hf_revision=None):
+                    progress=False, hf_model_id=None, hf_revision=None):
     cfg = cfg or getattr(model, 'default_cfg')
-    if hf_checkpoint is None:
-        hg_checkpoint = cfg.get('hf_checkpoint')
+    if hf_model_id is None:
+        hg_checkpoint = cfg.get('hf_model_id')
     if hf_revision is None:
         hg_revision = cfg.get('hf_revision')
-    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_model_id is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    if hf_checkpoint is not None:
+    if hf_model_id is not None:
         # TODO, progress is ignored.
-        url = hf_hub_url(hf_checkpoint, "pytorch_model.bin", revision=hf_revision)
+        url = hf_hub_url(hf_model_id, "pytorch_model.bin", revision=hf_revision)
         cached_filed = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
@@ -358,7 +358,7 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
-    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_model_id = kwargs.pop('hf_model_id', None)
     hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}
@@ -377,7 +377,7 @@ def build_model_with_cfg(
 
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained or hf_checkpoint is not None:
+    if pretrained or hf_model_id is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
@@ -385,7 +385,7 @@ def build_model_with_cfg(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                 filter_fn=pretrained_filter_fn, strict=pretrained_strict,
-                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
+                hf_model_id=hf_model_id, hf_revision=hf_revision)
 
     if features:
         feature_cls = FeatureListNet
diff --git a/timm/models/registry.py b/timm/models/registry.py
index 8930c1ae..da0df3f4 100644
--- a/timm/models/registry.py
+++ b/timm/models/registry.py
@@ -37,7 +37,7 @@ def register_model(fn):
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
         has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
-        has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
+        has_pretrained = has_pretrained or 'hf_model_id' in mod.default_cfgs[model_name]
     if has_pretrained:
         _model_has_pretrained.add(model_name)
     return fn
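
For reference, after this rename a Hub-hosted checkpoint is referred to by an `hf_model_id` key (plus an optional `hf_revision`) rather than `hf_checkpoint`. The sketch below shows the two ways the key would be consumed once this patch is applied; the repository id is a placeholder, not a real checkpoint, and `resnet50` is only used as an example entrypoint.

    import timm

    # 1. Passed explicitly: create_model() forwards unknown kwargs to the model
    #    entrypoint, and build_model_with_cfg() pops hf_model_id / hf_revision
    #    and hands them to load_pretrained(), as in the helpers.py hunk above.
    model = timm.create_model(
        'resnet50',
        hf_model_id='someuser/resnet50-custom-weights',  # placeholder repo id
        hf_revision='main',                              # optional git revision/tag
    )

    # 2. Carried in a model's default_cfg: register_model() now treats a cfg
    #    containing 'hf_model_id' as having pretrained weights, so the weights
    #    can be resolved from the Hub instead of a plain 'url' entry.
    default_cfg = {
        'num_classes': 1000,
        'input_size': (3, 224, 224),
        'hf_model_id': 'someuser/resnet50-custom-weights',  # placeholder repo id
    }

In both cases the weight file is resolved through `hf_hub_url(..., "pytorch_model.bin", revision=...)` and fetched with `cached_download()` into timm's cache directory, as shown in `load_pretrained` above.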