Consistently use hf_model_id

pull/440/head
Sylvain Gugger 4 years ago
parent 482ab548dc
commit f269f2d9d5
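
This commit renames the hf_checkpoint keyword argument and config key to hf_model_id everywhere the Hugging Face hub integration touches it: the create_model entrypoint call, the load_pretrained signature and its default_cfg fallbacks, the kwargs handling in build_model_with_cfg, and the has_pretrained check in register_model.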

@@ -55,7 +55,7 @@ def create_model(
     try:
         model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
         create_fn = model_entrypoint(model_cfg.pop("architecture"))
-        model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
+        model = create_fn(**model_args, hf_model_id=model_name, **kwargs)
         # Probably need some extra stuff, but this is a PoC of how the config in the model hub
         # could overwrite the default config values.
         # model.default_cfg.update(model_cfg)
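
In this proof of concept, model_name itself is the hub repo id: the config is fetched from the hub, its "architecture" field selects the entrypoint, and hf_model_id=model_name is forwarded so load_pretrained can fetch the weights from the same repo. A minimal sketch of a call after the rename; the repo id below is a hypothetical placeholder, not a real checkpoint:

    import timm

    # Hypothetical call: the repo id doubles as the model name in this PoC,
    # and hf_revision is forwarded via kwargs as in the hunk above.
    model = timm.create_model('some-user/resnet50-example', hf_revision='main')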

@@ -184,19 +184,19 @@ def adapt_input_conv(in_chans, conv_weight):
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
-                    progress=False, hf_checkpoint=None, hf_revision=None):
+                    progress=False, hf_model_id=None, hf_revision=None):
     cfg = cfg or getattr(model, 'default_cfg')
-    if hf_checkpoint is None:
-        hg_checkpoint = cfg.get('hf_checkpoint')
+    if hf_model_id is None:
+        hg_checkpoint = cfg.get('hf_model_id')
     if hf_revision is None:
         hg_revision = cfg.get('hf_revision')
-    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_model_id is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    if hf_checkpoint is not None:
+    if hf_model_id is not None:
         # TODO, progress is ignored.
-        url = hf_hub_url(hf_checkpoint, "pytorch_model.bin", revision=hf_revision)
+        url = hf_hub_url(hf_model_id, "pytorch_model.bin", revision=hf_revision)
         cached_filed = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
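
Two pre-existing typos survive the rename untouched: the fallback branches assign to hg_checkpoint and hg_revision (note the hg_), so the values pulled from default_cfg are never actually read, and cached_filed is presumably meant to be cached_file. A sketch of what the fallback logic appears to intend; this fix is my assumption, not part of the commit:

    # Presumed intent (not in this commit): fall back to default_cfg so the
    # later checks on hf_model_id / hf_revision see the config-supplied values.
    if hf_model_id is None:
        hf_model_id = cfg.get('hf_model_id')
    if hf_revision is None:
        hf_revision = cfg.get('hf_revision')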
@@ -358,7 +358,7 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
-    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_model_id = kwargs.pop('hf_model_id', None)
     hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}
@@ -377,7 +377,7 @@ def build_model_with_cfg(
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained or hf_checkpoint is not None:
+    if pretrained or hf_model_id is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
@@ -385,7 +385,7 @@ def build_model_with_cfg(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
                 filter_fn=pretrained_filter_fn, strict=pretrained_strict,
-                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
+                hf_model_id=hf_model_id, hf_revision=hf_revision)
     if features:
         feature_cls = FeatureListNet
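
End to end, the renamed kwarg rides along in **kwargs from the entrypoint into build_model_with_cfg, which pops it and hands it to load_pretrained. A hedged sketch of a typical timm entrypoint to illustrate the flow; the resnet50-specific arguments are illustrative and the exact signatures at this point in the PR are assumptions:

    @register_model
    def resnet50(pretrained=False, **kwargs):
        # hf_model_id, if passed by create_model, arrives here in **kwargs
        # and is popped inside build_model_with_cfg.
        return build_model_with_cfg(
            ResNet, 'resnet50', pretrained, default_cfg=default_cfgs['resnet50'],
            block=Bottleneck, layers=[3, 4, 6, 3], **kwargs)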

@@ -37,7 +37,7 @@ def register_model(fn):
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
         has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
-        has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
+        has_pretrained = has_pretrained or 'hf_model_id' in mod.default_cfgs[model_name]
     if has_pretrained:
         _model_has_pretrained.add(model_name)
     return fn
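
With this last hunk, a model whose default_cfgs entry carries an hf_model_id key is flagged as having pretrained weights even when no direct download URL is set. A hypothetical entry; the repo id is a placeholder and the other keys follow timm's usual _cfg conventions:

    # Hypothetical default_cfgs entry; the repo id is a placeholder.
    default_cfgs = {
        'resnet50': {
            'url': '',  # no direct URL for the weights
            'hf_model_id': 'some-user/resnet50-example',  # makes has_pretrained True
            'num_classes': 1000, 'input_size': (3, 224, 224),
        },
    }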
