diff --git a/timm/models/factory.py b/timm/models/factory.py
index a7b6c90e..6f196208 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -1,5 +1,5 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
-from .helpers import load_checkpoint
+from .helpers import load_checkpoint, load_hf_checkpoint_config
 from .layers import set_layer_config
 
 
@@ -52,7 +52,16 @@ def create_model(
             create_fn = model_entrypoint(model_name)
             model = create_fn(**model_args, **kwargs)
         else:
-            raise RuntimeError('Unknown model (%s)' % model_name)
+            try:
+                model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
+                # This does not work, but there is probably a way to have the values in the config override
+                # the defaults if needed.
+                # model_args["model_cfg"] = model_cfg
+                create_fn = model_entrypoint(model_cfg.pop("architecture"))
+                model = create_fn(**model_args, **kwargs)
+            except Exception as e:
+                raise RuntimeError('Unknown model or checkpoint on the Hugging Face hub (%s)' % model_name) from e
+
 
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 5c174fad..fe1f509d 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -3,11 +3,12 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import json
 import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -183,16 +184,21 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight
 
 
-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
+                    progress=False, hf_checkpoint=None, hf_revision=None):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
+    if hf_checkpoint is None and cfg is not None:
+        hf_checkpoint = cfg.get('hf_checkpoint')
+    if hf_revision is None and cfg is not None:
+        hf_revision = cfg.get('hf_revision')
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
-    if cfg.get('hf_checkpoint') is not None:
+    if hf_checkpoint is not None:
         # TODO, progress is ignored.
-        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
         cached_filed = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
@@ -354,6 +360,8 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
+    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}
 
@@ -371,14 +379,15 @@
 
     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained:
+    if pretrained or hf_checkpoint is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
             load_pretrained(
                 model, num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                filter_fn=pretrained_filter_fn, strict=pretrained_strict,
+                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)
 
     if features:
         feature_cls = FeatureListNet
@@ -394,3 +403,17 @@
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg
 
     return model
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    return json.loads(text)
+
+
+def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
+    url = hf_hub_url(checkpoint, "config.json", revision=revision)
+    cached_file = cached_download(
+        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )
+    return load_cfg_from_json(cached_file)
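
As a usage note, here is a minimal sketch of how the two new code paths could be exercised. The repo id below is a hypothetical placeholder, not a real Hugging Face hub checkpoint.

    import timm

    # Explicit hub checkpoint for a registered architecture: hf_checkpoint and
    # hf_revision are popped in build_model_with_cfg and forwarded to
    # load_pretrained, which downloads pytorch_model.pth from the repo.
    model = timm.create_model('resnet50', hf_checkpoint='some-user/resnet50-demo', hf_revision='main')

    # Unregistered name: create_model falls through to the hub, downloads
    # config.json, and dispatches on its "architecture" entry. As written this
    # only resolves the architecture; weights still come from the entrypoint's
    # default_cfg (or an explicit hf_checkpoint kwarg), not automatically from
    # the same repo.
    model = timm.create_model('some-user/resnet50-demo')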
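
The config.json that load_hf_checkpoint_config fetches only needs an "architecture" key naming a registered timm entrypoint; any other keys are currently unused, pending the commented-out model_cfg override in create_model. A sketch of producing such a file, with assumed minimal contents:

    import json

    # Hypothetical minimal config.json for a hub repo; "architecture" is popped
    # by create_model and must match a registered timm entrypoint.
    with open("config.json", "w", encoding="utf-8") as f:
        json.dump({"architecture": "resnet50"}, f)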