👽 use hf_hub_download instead of cached_download

pull/1351/head
nateraw 2 years ago
parent 324a4e58b6
commit 51cca82aa1

@ -14,11 +14,11 @@ except ImportError:
from timm import __version__ from timm import __version__
try: try:
from huggingface_hub import HfApi, HfFolder, Repository, cached_download, hf_hub_url from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
cached_download = partial(cached_download, library_name="timm", library_version=__version__) hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
_has_hf_hub = True _has_hf_hub = True
except ImportError: except ImportError:
cached_download = None hf_hub_download = None
_has_hf_hub = False _has_hf_hub = False
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
@ -78,8 +78,7 @@ def load_cfg_from_json(json_file: Union[str, os.PathLike]):
def _download_from_hf(model_id: str, filename: str): def _download_from_hf(model_id: str, filename: str):
hf_model_id, hf_revision = hf_split(model_id) hf_model_id, hf_revision = hf_split(model_id)
url = hf_hub_url(hf_model_id, filename, revision=hf_revision) return hf_hub_download(hf_model_id, filename, revision=hf_revision, cache_dir=get_cache_dir('hf'))
return cached_download(url, cache_dir=get_cache_dir('hf'))
def load_model_config_from_hf(model_id: str): def load_model_config_from_hf(model_id: str):

Loading…
Cancel
Save