From a00011188e248d0380618533a8e78269bf489f75 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger Date: Thu, 11 Feb 2021 09:05:15 -0500 Subject: [PATCH] PoC to add HF hub --- timm/models/helpers.py | 47 ++++++++++++++++++++++++------------------ timm/models/resnet.py | 2 +- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d9b501da..f743f600 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -17,13 +17,31 @@ try: except ImportError: from torch.hub import _get_torch_home as get_dir +from huggingface_hub import cached_download + from .features import FeatureListNet, FeatureDictNet, FeatureHookNet from .layers import Conv2dSame, Linear +from ..version import __version__ _logger = logging.getLogger(__name__) +def get_cache_dir(): + """ + Returns the location of the directory where models are cached (and creates it if necessary). + """ + # Issue warning to move data if old env is set + if os.getenv('TORCH_MODEL_ZOO'): + _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') + + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + return model_dir + + def load_state_dict(checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): checkpoint = torch.load(checkpoint_path, map_location='cpu') @@ -120,25 +138,10 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_ return url = cfg['url'] - # Issue warning to move data if old env is set - if os.getenv('TORCH_MODEL_ZOO'): - _logger.warning('TORCH_MODEL_ZOO is deprecated, please use env TORCH_HOME instead') - - hub_dir = get_dir() - model_dir = os.path.join(hub_dir, 'checkpoints') - - os.makedirs(model_dir, exist_ok=True) - - parts = urlparse(url) - filename = os.path.basename(parts.path) - cached_file = os.path.join(model_dir, filename) - if not os.path.exists(cached_file): - _logger.info('Downloading: "{}" to {}\n'.format(url, cached_file)) - hash_prefix = None - if check_hash: - r = HASH_REGEX.search(filename) # r is Optional[Match[str]] - hash_prefix = r.group(1) if r else None - download_url_to_file(url, cached_file, hash_prefix, progress=progress) + # TODO, progress and check_hash are ignored. + cached_filed = cached_download( + url, library_name="timm", library_version=__version__, cache_dir=get_cache_dir() + ) if load_fn is not None: load_fn(model, cached_file) @@ -180,7 +183,11 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non _logger.warning("No pretrained weights exist for this model. Using random initialization.") return - state_dict = load_state_dict_from_url(cfg['url'], progress=progress, map_location='cpu') + # TODO, progress is ignored. + cached_filed = cached_download( + cfg['url'], library_name="timm", library_version=__version__, cache_dir=get_cache_dir() + ) + state_dict = torch.load(cached_filed, map_location='cpu') if filter_fn is not None: state_dict = filter_fn(state_dict) diff --git a/timm/models/resnet.py b/timm/models/resnet.py index 6dec9d53..fb1052f2 100644 --- a/timm/models/resnet.py +++ b/timm/models/resnet.py @@ -52,7 +52,7 @@ default_cfgs = { url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50_ram-a26f946b.pth', interpolation='bicubic'), 'resnet50d': _cfg( - url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet50d_ra2-464e36ba.pth', + url='https://huggingface.co/sgugger/resnet50d/resolve/main/pytorch_model.pth', interpolation='bicubic', first_conv='conv1.0'), 'resnet101': _cfg(url='', interpolation='bicubic'), 'resnet101d': _cfg(