diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index f743f600..5c174fad 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -17,7 +17,7 @@ try:
 except ImportError:
     from torch.hub import _get_torch_home as get_dir
 
-from huggingface_hub import cached_download
+from huggingface_hub import cached_download, hf_hub_url
 
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
@@ -138,10 +138,17 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         return
     url = cfg['url']
 
-    # TODO, progress and check_hash are ignored.
-    cached_filed = cached_download(
-        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
+    model_dir = get_cache_dir()
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
 
     if load_fn is not None:
         load_fn(model, cached_file)
@@ -179,15 +186,19 @@ def adapt_input_conv(in_chans, conv_weight):
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    if cfg is None or (not cfg.get('url') and not cfg.get('hf_checkpoint')):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return
 
-    # TODO, progress is ignored.
-    cached_filed = cached_download(
-        cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
-    state_dict = torch.load(cached_filed, map_location='cpu')
+    if cfg.get('hf_checkpoint') is not None:
+        # TODO, progress is ignored.
+        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        cached_file = cached_download(
+            url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+        )
+        state_dict = torch.load(cached_file, map_location='cpu')
+    else:
+        state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
 
     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
diff --git a/timm/models/registry.py b/timm/models/registry.py
index 3317eece..8930c1ae 100644
--- a/timm/models/registry.py
+++ b/timm/models/registry.py
@@ -37,6 +37,7 @@ def register_model(fn):
         # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
         # entrypoints or non-matching combos
         has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
+        has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
         if has_pretrained:
             _model_has_pretrained.add(model_name)
     return fn
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index fb1052f2..64760a6d 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -52,7 +52,8 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
         interpolation='bicubic'),
     'resnet50d': _cfg(
-        url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
+        hf_revision="main", hf_checkpoint="sgugger/resnet50d",
         interpolation='bicubic', first_conv='conv1.0'),
     'resnet101': _cfg(url='', interpolation='bicubic'),
     'resnet101d': _cfg(
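
Usage sketch (not part of the patch; everything except the 'hf_checkpoint' / 'hf_revision' cfg keys introduced above is illustrative): with this change a model's default_cfg can point at a Hugging Face Hub repo instead of a raw URL, and load_pretrained() then resolves pytorch_model.pth through hf_hub_url() + cached_download() rather than load_state_dict_from_url(). Assuming a timm build with this patch applied, the new path can be exercised roughly like this:

    import timm
    from timm.models.helpers import load_pretrained

    # Build the model without weights, then load them from the Hub repo named in resnet.py's cfg.
    model = timm.create_model('resnet50d', pretrained=False)
    cfg = dict(model.default_cfg)
    cfg.update(hf_checkpoint='sgugger/resnet50d', hf_revision='main')  # Hub repo id + revision
    load_pretrained(model, cfg=cfg)  # fetches pytorch_model.pth via hf_hub_url/cached_download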