From ebe69dd4d3729d33396da1cdd947ea6394c4bb61 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 11 Feb 2021 09:13:49 -0500
Subject: [PATCH 1/5] PoC for using HF model hub

---
 timm/models/helpers.py | 47 ++++++++++++++++++++++++------------------
 timm/models/resnet.py  |  2 +-
 2 files changed, 28 insertions(+), 21 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index d9b501da..f743f600 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -17,13 +17,31 @@ try:
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+from huggingface_hub import cached_download
+
 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
+from ..version import __version__


 _logger = logging.getLogger(__name__)


+def get_cache_dir():
+    """
+    Returns the location of the directory where models are cached (and creates it if necessary).
+    """
+    # Issue warning to move data if old env is set
+    if os.getenv('TORCH_MODEL_ZOO'):
+        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
+
+    hub_dir = get_dir()
+    model_dir = os.path.join(hub_dir, 'checkpoints')
+
+    os.makedirs(model_dir, exist_ok=True)
+    return model_dir
+
+
 def load_state_dict(checkpoint_path, use_ema=False):
     if checkpoint_path and os.path.isfile(checkpoint_path):
         checkpoint = torch.load(checkpoint_path, map_location='cpu')
@@ -120,25 +138,10 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         return
     url = cfg['url']

-    # Issue warning to move data if old env is set
-    if os.getenv('TORCH_MODEL_ZOO'):
-        _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead')
-
-    hub_dir = get_dir()
-    model_dir = os.path.join(hub_dir, 'checkpoints')
-
-    os.makedirs(model_dir, exist_ok=True)
-
-    parts = urlparse(url)
-    filename = os.path.basename(parts.path)
-    cached_file = os.path.join(model_dir, filename)
-    if not os.path.exists(cached_file):
-        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
-        hash_prefix = None
-        if check_hash:
-            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
-            hash_prefix = r.group(1) if r else None
-        download_url_to_file(url, cached_file, hash_prefix, progress=progress)
+    # TODO, progress and check_hash are ignored.
+    cached_filed = cached_download(
+        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )

     if load_fn is not None:
         load_fn(model, cached_file)
@@ -180,7 +183,11 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return

-    state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')
+    # TODO, progress is ignored.
+    cached_filed = cached_download(
+        cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )
+    state_dict = torch.load(cached_filed, map_location='cpu')

     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index 6dec9d53..fb1052f2 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -52,7 +52,7 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
         interpolation='bicubic'),
     'resnet50d': _cfg(
-        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
+        url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth',
         interpolation='bicubic', first_conv='conv1.0'),
     'resnet101': _cfg(url='', interpolation='bicubic'),
     'resnet101d': _cfg(
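Patch 1 reroutes every pretrained-weight download through huggingface_hub instead of torch.hub's downloader. Below is a standalone sketch of that download path, assuming the early cached_download API the patch uses (since deprecated in favor of hf_hub_download) and a placeholder library version where the patch passes timm's real __version__. Note that load_custom_pretrained above assigns the result to cached_filed while load_fn receives cached_file, so that path would raise a NameError as committed; patch 2 rewrites it anyway.

import os

import torch
from huggingface_hub import cached_download

# Same cache layout as the new get_cache_dir(): <torch hub dir>/checkpoints.
cache_dir = os.path.join(torch.hub.get_dir(), 'checkpoints')
os.makedirs(cache_dir, exist_ok=True)

# The resnet50d URL the patch points at; any raw file URL works here.
url = 'https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth'

# cached_download fetches once and reuses the local copy on later calls;
# library_name/library_version only tag the request's user-agent.
cached_file = cached_download(
    url, library_name='timm', library_version='0.4.4', cache_dir=cache_dir)
state_dict = torch.load(cached_file, map_location='cpu')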
From 9857e12c0cd96bc0a3c989a553a6cf33cb04106e Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 11 Feb 2021 09:33:12 -0500
Subject: [PATCH 2/5] Be more opt-in

---
 timm/models/helpers.py  | 33 ++++++++++++++++++++++-----------
 timm/models/registry.py |  1 +
 timm/models/resnet.py   |  3 ++-
 3 files changed, 25 insertions(+), 12 deletions(-)

diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index f743f600..5c174fad 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -17,7 +17,7 @@ try:
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

-from huggingface_hub import cached_download
+from huggingface_hub import cached_download, hf_hub_url

 from .features import FeatureListNet, FeatureDictNet, FeatureHookNet
 from .layers import Conv2dSame, Linear
@@ -138,10 +138,17 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         return
     url = cfg['url']

-    # TODO, progress and check_hash are ignored.
-    cached_filed = cached_download(
-        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
+    model_dir = get_cache_dir()
+    parts = urlparse(url)
+    filename = os.path.basename(parts.path)
+    cached_file = os.path.join(model_dir, filename)
+    if not os.path.exists(cached_file):
+        _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file))
+        hash_prefix = None
+        if check_hash:
+            r = HASH_REGEX.search(filename)  # r is Optional[Match[str]]
+            hash_prefix = r.group(1) if r else None
+        download_url_to_file(url, cached_file, hash_prefix, progress=progress)

     if load_fn is not None:
         load_fn(model, cached_file)
@@ -179,15 +186,19 @@ def adapt_input_conv(in_chans, conv_weight):
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url']:
+    if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return

-    # TODO, progress is ignored.
-    cached_filed = cached_download(
-        cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
-    )
-    state_dict = torch.load(cached_filed, map_location='cpu')
+    if cfg.get('hf_checkpoint') is not None:
+        # TODO, progress is ignored.
+        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        cached_filed = cached_download(
+            url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+        )
+        state_dict = torch.load(cached_filed, map_location='cpu')
+    else:
+        state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu')

     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
diff --git a/timm/models/registry.py b/timm/models/registry.py
index 3317eece..8930c1ae 100644
--- a/timm/models/registry.py
+++ b/timm/models/registry.py
@@ -37,6 +37,7 @@ def register_model(fn):
     # this will catch all models that have entrypoint matching cfg key, but miss any aliasing
     # entrypoints or non-matching combos
     has_pretrained = 'url' in mod.default_cfgs[model_name] and 'http' in mod.default_cfgs[model_name]['url']
+    has_pretrained = has_pretrained or 'hf_checkpoint' in mod.default_cfgs[model_name]
     if has_pretrained:
         _model_has_pretrained.add(model_name)
     return fn
diff --git a/timm/models/resnet.py b/timm/models/resnet.py
index fb1052f2..64760a6d 100644
--- a/timm/models/resnet.py
+++ b/timm/models/resnet.py
@@ -52,7 +52,8 @@ default_cfgs = {
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth',
         interpolation='bicubic'),
     'resnet50d': _cfg(
-        url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth',
+        url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth',
+        hf_revision="master", hf_model_id="sgugger/resnet50d",
         interpolation='bicubic', first_conv='conv1.0'),
     'resnet101': _cfg(url='', interpolation='bicubic'),
     'resnet101d': _cfg(
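Patch 2 restores the torch.hub download path for plain URLs and makes the Hub path opt-in per model through extra default_cfg keys. A minimal sketch of the branch load_pretrained now takes, using the hf_checkpoint key the loader reads; note the resnet50d cfg spells the key hf_model_id, so as committed that entry would still resolve through its GitHub url.

import torch
from huggingface_hub import cached_download, hf_hub_url


def load_hub_state_dict(cfg):
    # hf_hub_url only composes the canonical
    # https://huggingface.co/<repo>/resolve/<revision>/<filename> URL;
    # cached_download performs the actual fetch-and-cache (default hub cache
    # when no cache_dir is given).
    url = hf_hub_url(cfg['hf_checkpoint'], 'pytorch_model.pth',
                     revision=cfg.get('hf_revision'))
    cached_file = cached_download(url, library_name='timm')
    return torch.load(cached_file, map_location='cpu')


state_dict = load_hub_state_dict(
    {'hf_checkpoint': 'sgugger/resnet50d', 'hf_revision': 'master'})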
From 4d0b5feed04c3bef81a6d207d100114b64a35af9 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 11 Feb 2021 10:34:15 -0500
Subject: [PATCH 3/5] Make `create_model` work from a checkpoint

---
 timm/models/factory.py | 13 +++++++++++--
 timm/models/helpers.py | 37 ++++++++++++++++++++++++++++++-------
 2 files changed, 41 insertions(+), 9 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index a7b6c90e..6f196208 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -1,5 +1,5 @@
 from .registry import is_model, is_model_in_modules, model_entrypoint
-from .helpers import load_checkpoint
+from .helpers import load_checkpoint, load_hf_checkpoint_config
 from .layers import set_layer_config

@@ -52,7 +52,16 @@ def create_model(
         create_fn = model_entrypoint(model_name)
         model = create_fn(**model_args, **kwargs)
     else:
-        raise RuntimeError('Unknown model (%s)' % model_name)
+        try:
+            model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
+            # This does not work, but there is probably a way to have the values in the config override
+            # the defaults if needed.
+            # model_args["model_cfg"] = model_cfg
+            create_fn = model_entrypoint(model_cfg.pop("architecture"))
+            model = create_fn(**model_args, **kwargs)
+        except Exception as e:
+            raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)
+
     if checkpoint_path:
         load_checkpoint(model, checkpoint_path)
diff --git a/timm/models/helpers.py b/timm/models/helpers.py
index 5c174fad..fe1f509d 100644
--- a/timm/models/helpers.py
+++ b/timm/models/helpers.py
@@ -3,11 +3,12 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 import logging
+import json
 import os
 import math
 from collections import OrderedDict
 from copy import deepcopy
-from typing import Callable
+from typing import Callable, Optional, Union

 import torch
 import torch.nn as nn
@@ -183,16 +184,21 @@ def adapt_input_conv(in_chans, conv_weight):
     return conv_weight


-def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
+def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True,
+                    progress=False, hf_checkpoint=None, hf_revision=None):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
-    if cfg is None or 'url' not in cfg or not cfg['url'] or 'hf_checkpoint' not in cfg or not cfg['hf_checkpoint']:
+    if hf_checkpoint is None:
+        hg_checkpoint = cfg.get('hf_checkpoint')
+    if hf_revision is None:
+        hg_revision = cfg.get('hf_revision')
+    if cfg is None or (('url' not in cfg or not cfg['url']) and hf_checkpoint is None):
         _logger.warning("No pretrained weights exist for this model. Using random initialization.")
         return

-    if cfg.get('hf_checkpoint') is not None:
+    if hf_checkpoint is not None:
         # TODO, progress is ignored.
-        url = hf_hub_url(cfg['hf_checkpoint'], "pytorch_model.pth", revision=cfg.get('hf_revision'))
+        url = hf_hub_url(hf_checkpoint, "pytorch_model.pth", revision=hf_revision)
         cached_filed = cached_download(
             url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
         )
@@ -354,6 +360,8 @@ def build_model_with_cfg(
         pretrained_custom_load: bool = False,
         **kwargs):
     pruned = kwargs.pop('pruned', False)
+    hf_checkpoint = kwargs.pop('hf_checkpoint', None)
+    hf_revision = kwargs.pop('hf_revision', None)
     features = False
     feature_cfg = feature_cfg or {}

@@ -371,14 +379,15 @@ def build_model_with_cfg(

     # for classification models, check class attr, then kwargs, then default to 1k, otherwise 0 for feats
     num_classes_pretrained = 0 if features else getattr(model, 'num_classes', kwargs.get('num_classes', 1000))
-    if pretrained:
+    if pretrained or hf_checkpoint is not None:
         if pretrained_custom_load:
             load_custom_pretrained(model)
         else:
             load_pretrained(
                 model,
                 num_classes=num_classes_pretrained, in_chans=kwargs.get('in_chans', 3),
-                filter_fn=pretrained_filter_fn, strict=pretrained_strict)
+                filter_fn=pretrained_filter_fn, strict=pretrained_strict,
+                hf_checkpoint=hf_checkpoint, hf_revision=hf_revision)

     if features:
         feature_cls = FeatureListNet
@@ -394,3 +403,17 @@ def build_model_with_cfg(
         model.default_cfg = default_cfg_for_features(default_cfg)  # add back default_cfg

     return model
+
+
+def load_cfg_from_json(json_file: Union[str, os.PathLike]):
+    with open(json_file, "r", encoding="utf-8") as reader:
+        text = reader.read()
+    return json.loads(text)
+
+
+def load_hf_checkpoint_config(checkpoint: str, revision: Optional[str] = None):
+    url = hf_hub_url(checkpoint, "config.json", revision=revision)
+    cached_filed = cached_download(
+        url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir()
+    )
+    return load_cfg_from_json(cached_filed)
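Patch 3 lets create_model fall back to the Hub when the name is not a registered architecture: it fetches the repo's config.json and pops its "architecture" key to select the entrypoint. (As committed, the cfg fallback inside load_pretrained assigns hg_checkpoint/hg_revision, apparent typos that leave the hf_ variables untouched.) A sketch of the config round trip, assuming a repo layout where config.json's only required key is "architecture" and the remaining keys are free-form default_cfg material:

import json

from huggingface_hub import cached_download, hf_hub_url

# What load_hf_checkpoint_config boils down to: fetch config.json, parse it.
url = hf_hub_url('sgugger/resnet50d', 'config.json', revision=None)
with open(cached_download(url, library_name='timm'), 'r', encoding='utf-8') as reader:
    model_cfg = json.loads(reader.read())

# e.g. {"architecture": "resnet50d", "num_classes": 1000, ...} -- keys are
# illustrative; "architecture" must name a registered timm entrypoint.
architecture = model_cfg.pop('architecture')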
From 7dafa71c0e9465efdfb47cbfba1b498cfa4259b8 Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 11 Feb 2021 10:47:30 -0500
Subject: [PATCH 4/5] Use the model hub config

---
 timm/models/factory.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index 6f196208..fd3ce804 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -54,11 +54,11 @@ def create_model(
     else:
         try:
             model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
-            # This does not work, but there is probably a way to have the values in the config override
-            # the defaults if needed.
-            # model_args["model_cfg"] = model_cfg
             create_fn = model_entrypoint(model_cfg.pop("architecture"))
             model = create_fn(**model_args, **kwargs)
+            # Probably need some extra stuff, but this is a PoC of how the config in the model hub
+            # could overwrite the default config values.
+            model.default_cfg.update(model_cfg)
         except Exception as e:
             raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)

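Patch 4 applies the hub config after the model is built, by updating its default_cfg in place. The override is plain dict.update semantics, so hub values win key-by-key and keys the entrypoint never set are carried along (the values below are illustrative, not from a real repo):

# default_cfg as a registered entrypoint would set it.
default_cfg = {'num_classes': 1000, 'interpolation': 'bicubic', 'first_conv': 'conv1.0'}

# Remainder of a hub config.json after "architecture" has been popped.
model_cfg = {'interpolation': 'bilinear', 'crop_pct': 0.9}

default_cfg.update(model_cfg)
assert default_cfg['interpolation'] == 'bilinear'  # hub value wins
assert default_cfg['first_conv'] == 'conv1.0'      # untouched keys survive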
From 5482d2e501500c27afafd15a4799732c6e2eb24d Mon Sep 17 00:00:00 2001
From: Sylvain Gugger
Date: Thu, 11 Feb 2021 11:11:26 -0500
Subject: [PATCH 5/5] Actually load the model

---
 timm/models/factory.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/models/factory.py b/timm/models/factory.py
index fd3ce804..fa23150c 100644
--- a/timm/models/factory.py
+++ b/timm/models/factory.py
@@ -55,10 +55,10 @@ def create_model(
         try:
             model_cfg = load_hf_checkpoint_config(model_name, revision=kwargs.get("hf_revision"))
             create_fn = model_entrypoint(model_cfg.pop("architecture"))
-            model = create_fn(**model_args, **kwargs)
+            model = create_fn(**model_args, hf_checkpoint=model_name, **kwargs)
             # Probably need some extra stuff, but this is a PoC of how the config in the model hub
             # could overwrite the default config values.
-            model.default_cfg.update(model_cfg)
+            # model.default_cfg.update(model_cfg)
         except Exception as e:
             raise RuntimeError('Unknown model or checkpoint from the Hugging Face hub (%s)' % model_name)

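End state of the series: a repo id that is not a registered model name routes through config.json for the architecture and through hf_checkpoint=model_name for the weights (timm entrypoints forward unknown kwargs to build_model_with_cfg, which pops hf_checkpoint/hf_revision). A usage sketch, assuming the patched timm is on the path and the sgugger/resnet50d repo from the patches:

import timm

# Registered architecture names behave exactly as before.
model = timm.create_model('resnet50d', pretrained=True)

# A Hub repo id falls through to the new path: config.json selects the
# entrypoint and pytorch_model.pth is loaded from the same repo. (With
# default_cfg.update commented out again in patch 5, the hub config now
# only supplies the architecture name.)
model = timm.create_model('sgugger/resnet50d')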