@ -2,10 +2,11 @@ import json
import logging
import logging
import os
import os
from functools import partial
from functools import partial
from typing import Union , Optional
from pathlib import Path
from typing import Union
import torch
import torch
from torch . hub import load_state_dict_from_url , download_url_to_file , urlparse , HASH_REGEX
from torch . hub import HASH_REGEX , download_url_to_file , urlparse
try :
try :
from torch . hub import get_dir
from torch . hub import get_dir
except ImportError :
except ImportError :
@ -13,12 +14,12 @@ except ImportError:
from timm import __version__
from timm import __version__
try :
try :
from huggingface_hub import hf_hub_url
from huggingface_hub import HfApi , HfFolder , Repository , cached_download , hf_hub_url
from huggingface_hub import cached_download
cached_download = partial ( cached_download , library_name = " timm " , library_version = __version__ )
cached_download = partial ( cached_download , library_name = " timm " , library_version = __version__ )
_has_hf_hub = True
except ImportError :
except ImportError :
hf_hub_url = None
cached_download = None
cached_download = None
_has_hf_hub = False
_logger = logging . getLogger ( __name__ )
_logger = logging . getLogger ( __name__ )
@ -53,11 +54,11 @@ def download_cached_file(url, check_hash=True, progress=False):
def has_hf_hub ( necessary = False ) :
def has_hf_hub ( necessary = False ) :
if hf_hub_url is None and necessary :
if not _has_ hf_hub and necessary :
# if no HF Hub module installed and it is necessary to continue, raise error
# if no HF Hub module installed and it is necessary to continue, raise error
raise RuntimeError (
raise RuntimeError (
' Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`. ' )
' Hugging Face hub model specified but package not installed. Run `pip install huggingface_hub`. ' )
return hf_hub_url is not None
return _has_ hf_hub
def hf_split ( hf_id ) :
def hf_split ( hf_id ) :
@ -94,3 +95,77 @@ def load_state_dict_from_hf(model_id: str):
cached_file = _download_from_hf ( model_id , ' pytorch_model.bin ' )
cached_file = _download_from_hf ( model_id , ' pytorch_model.bin ' )
state_dict = torch . load ( cached_file , map_location = ' cpu ' )
state_dict = torch . load ( cached_file , map_location = ' cpu ' )
return state_dict
return state_dict
def save_for_hf ( model , save_directory , model_config = None ) :
assert has_hf_hub ( True )
model_config = model_config or { }
save_directory = Path ( save_directory )
save_directory . mkdir ( exist_ok = True , parents = True )
weights_path = save_directory / ' pytorch_model.bin '
torch . save ( model . state_dict ( ) , weights_path )
config_path = save_directory / ' config.json '
hf_config = model . default_cfg
hf_config [ ' num_classes ' ] = model_config . pop ( ' num_classes ' , model . num_classes )
hf_config [ ' num_features ' ] = model_config . pop ( ' num_features ' , model . num_features )
hf_config [ ' labels ' ] = model_config . pop ( ' labels ' , [ f " LABEL_ { i } " for i in range ( hf_config [ ' num_classes ' ] ) ] )
hf_config . update ( model_config )
with config_path . open ( ' w ' ) as f :
json . dump ( hf_config , f , indent = 2 )
def push_to_hf_hub (
model ,
local_dir ,
repo_namespace_or_url = None ,
commit_message = ' Add model ' ,
use_auth_token = True ,
git_email = None ,
git_user = None ,
revision = None ,
model_config = None ,
) :
if repo_namespace_or_url :
repo_owner , repo_name = repo_namespace_or_url . rstrip ( ' / ' ) . split ( ' / ' ) [ - 2 : ]
else :
if isinstance ( use_auth_token , str ) :
token = use_auth_token
else :
token = HfFolder . get_token ( )
if token is None :
raise ValueError (
" You must login to the Hugging Face hub on this computer by typing `transformers-cli login` and "
" entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
" token as the `use_auth_token` argument. "
)
repo_owner = HfApi ( ) . whoami ( token ) [ ' name ' ]
repo_name = Path ( local_dir ) . name
repo_url = f ' https://huggingface.co/ { repo_owner } / { repo_name } '
repo = Repository (
local_dir ,
clone_from = repo_url ,
use_auth_token = use_auth_token ,
git_user = git_user ,
git_email = git_email ,
revision = revision ,
)
# Prepare a default model card that includes the necessary tags to enable inference.
readme_text = f ' --- \n tags: \n - image-classification \n - timm \n library_tag: timm \n --- \n # Model card for { repo_name } '
with repo . commit ( commit_message ) :
# Save model weights and config.
save_for_hf ( model , repo . local_dir , model_config = model_config )
# Save a model card if it doesn't exist.
readme_path = Path ( repo . local_dir ) / ' README.md '
if not readme_path . exists ( ) :
readme_path . write_text ( readme_text )
return repo . git_remote_url ( )