|
|
|
@ -3,10 +3,12 @@ import logging
|
|
|
|
|
import os
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from typing import Union
|
|
|
|
|
from tempfile import TemporaryDirectory
|
|
|
|
|
from typing import Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from torch.hub import get_dir
|
|
|
|
|
except ImportError:
|
|
|
|
@ -15,7 +17,10 @@ except ImportError:
|
|
|
|
|
from timm import __version__
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from huggingface_hub import HfApi, HfFolder, Repository, hf_hub_download, hf_hub_url
|
|
|
|
|
from huggingface_hub import (create_repo, get_hf_file_metadata,
|
|
|
|
|
hf_hub_download, hf_hub_url,
|
|
|
|
|
repo_type_and_id_from_hf_id, upload_folder)
|
|
|
|
|
from huggingface_hub.utils import EntryNotFoundError
|
|
|
|
|
hf_hub_download = partial(hf_hub_download, library_name="timm", library_version=__version__)
|
|
|
|
|
_has_hf_hub = True
|
|
|
|
|
except ImportError:
|
|
|
|
@ -121,56 +126,45 @@ def save_for_hf(model, save_directory, model_config=None):
|
|
|
|
|
|
|
|
|
|
def push_to_hf_hub(
|
|
|
|
|
model,
|
|
|
|
|
local_dir,
|
|
|
|
|
repo_namespace_or_url=None,
|
|
|
|
|
commit_message='Add model',
|
|
|
|
|
use_auth_token=True,
|
|
|
|
|
git_email=None,
|
|
|
|
|
git_user=None,
|
|
|
|
|
revision=None,
|
|
|
|
|
model_config=None,
|
|
|
|
|
repo_id: str,
|
|
|
|
|
commit_message: str ='Add model',
|
|
|
|
|
token: Optional[str] = None,
|
|
|
|
|
revision: Optional[str] = None,
|
|
|
|
|
private: bool = False,
|
|
|
|
|
create_pr: bool = False,
|
|
|
|
|
model_config: Optional[dict] = None,
|
|
|
|
|
):
|
|
|
|
|
if isinstance(use_auth_token, str):
|
|
|
|
|
token = use_auth_token
|
|
|
|
|
else:
|
|
|
|
|
token = HfFolder.get_token()
|
|
|
|
|
if token is None:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"You must login to the Hugging Face hub on this computer by typing `huggingface-cli login` and "
|
|
|
|
|
"entering your credentials to use `use_auth_token=True`. Alternatively, you can pass your own "
|
|
|
|
|
"token as the `use_auth_token` argument."
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
if repo_namespace_or_url:
|
|
|
|
|
repo_owner, repo_name = repo_namespace_or_url.rstrip('/').split('/')[-2:]
|
|
|
|
|
else:
|
|
|
|
|
repo_owner = HfApi().whoami(token)['name']
|
|
|
|
|
repo_name = Path(local_dir).name
|
|
|
|
|
|
|
|
|
|
repo_id = f'{repo_owner}/{repo_name}'
|
|
|
|
|
repo_url = f'https://huggingface.co/{repo_id}'
|
|
|
|
|
|
|
|
|
|
# Create repo if doesn't exist yet
|
|
|
|
|
HfApi().create_repo(repo_id, token=use_auth_token, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
repo = Repository(
|
|
|
|
|
local_dir,
|
|
|
|
|
clone_from=repo_url,
|
|
|
|
|
use_auth_token=use_auth_token,
|
|
|
|
|
git_user=git_user,
|
|
|
|
|
git_email=git_email,
|
|
|
|
|
revision=revision,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Prepare a default model card that includes the necessary tags to enable inference.
|
|
|
|
|
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_name}'
|
|
|
|
|
with repo.commit(commit_message):
|
|
|
|
|
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
# Infer complete repo_id from repo_url
|
|
|
|
|
# Can be different from the input `repo_id` if repo_owner was implicit
|
|
|
|
|
_, repo_owner, repo_name = repo_type_and_id_from_hf_id(repo_url)
|
|
|
|
|
repo_id = f"{repo_owner}/{repo_name}"
|
|
|
|
|
|
|
|
|
|
# Check if README file already exist in repo
|
|
|
|
|
try:
|
|
|
|
|
get_hf_file_metadata(hf_hub_url(repo_id=repo_id, filename="README.md", revision=revision))
|
|
|
|
|
has_readme = True
|
|
|
|
|
except EntryNotFoundError:
|
|
|
|
|
has_readme = False
|
|
|
|
|
|
|
|
|
|
# Dump model and push to Hub
|
|
|
|
|
with TemporaryDirectory() as tmpdir:
|
|
|
|
|
# Save model weights and config.
|
|
|
|
|
save_for_hf(model, repo.local_dir, model_config=model_config)
|
|
|
|
|
save_for_hf(model, tmpdir, model_config=model_config)
|
|
|
|
|
|
|
|
|
|
# Save a model card if it doesn't exist.
|
|
|
|
|
readme_path = Path(repo.local_dir) / 'README.md'
|
|
|
|
|
if not readme_path.exists():
|
|
|
|
|
# Add readme if does not exist
|
|
|
|
|
if not has_readme:
|
|
|
|
|
readme_path = Path(tmpdir) / "README.md"
|
|
|
|
|
readme_text = f'---\ntags:\n- image-classification\n- timm\nlibrary_tag: timm\n---\n# Model card for {repo_id}'
|
|
|
|
|
readme_path.write_text(readme_text)
|
|
|
|
|
|
|
|
|
|
return repo.git_remote_url()
|
|
|
|
|
# Upload model and return
|
|
|
|
|
return upload_folder(
|
|
|
|
|
repo_id=repo_id,
|
|
|
|
|
folder_path=tmpdir,
|
|
|
|
|
revision=revision,
|
|
|
|
|
create_pr=create_pr,
|
|
|
|
|
commit_message=commit_message,
|
|
|
|
|
)
|
|
|
|
|