@@ -2,19 +2,25 @@ import hashlib
 import json
 import logging
 import os
+import sys
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Iterable, Optional, Union

 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+import safetensors.torch

 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal
+
 from timm import __version__
 from timm.models._pretrained import filter_pretrained_cfg

@@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
 __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
            'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version

 def get_cache_dir(child_dir=''):
     """
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name


-def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
+def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, filename)
-    state_dict = torch.load(cached_file, map_location='cpu')
-    return state_dict
+    hf_model_id, hf_revision = hf_split(model_id)
+
+    # Look for .safetensors alternatives and load from it if it exists
+    for safe_filename in _get_safe_alternatives(filename):
+        try:
+            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+            _logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
+            return safetensors.torch.load_file(cached_safe_file, device="cpu")
+        except EntryNotFoundError:
+            pass
+
+    # Otherwise, load using pytorch.load
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    _logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    return torch.load(cached_file, map_location='cpu')


 def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
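A minimal usage sketch of the new loading path (the repo id below is a placeholder, the module path is assumed to be `timm.models._hub`, and `huggingface_hub` plus `safetensors` are assumed to be installed): the loader tries each name yielded by `_get_safe_alternatives()` and only falls back to the pickle-based `.bin` file when no `.safetensors` file is found.

```python
# Illustrative only; "your-username/your-timm-model" is a hypothetical repo id.
from timm.models._hub import load_state_dict_from_hf

state_dict = load_state_dict_from_hf("your-username/your-timm-model")
print(f"{len(state_dict)} tensors loaded")  # from .safetensors if present, else pytorch_model.bin
```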
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
         json.dump(hf_config, f, indent=2)


-def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
+def save_for_hf(
+        model,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)

-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
+    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
+    tensors = model.state_dict()
+    if safe_serialization is True or safe_serialization == "both":
+        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
+    if safe_serialization is False or safe_serialization == "both":
+        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

     config_path = save_directory / 'config.json'
     save_config_for_hf(model, config_path, model_config=model_config)
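As a sketch of the new `safe_serialization` flag (assuming `timm`, `huggingface_hub` and `safetensors` are installed; the output directory name is arbitrary): passing `"both"` writes `model.safetensors` alongside the legacy `pytorch_model.bin`, plus the usual `config.json`.

```python
# Sketch: export a freshly created model in both weight formats.
import timm
from timm.models._hub import save_for_hf

model = timm.create_model("resnet18", pretrained=False)
save_for_hf(model, "./resnet18-export", safe_serialization="both")
# ./resnet18-export/ now contains model.safetensors, pytorch_model.bin and config.json
```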
@@ -217,7 +247,15 @@ def push_to_hf_hub(
         create_pr: bool = False,
         model_config: Optional[dict] = None,
         model_card: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
 ):
+    """
+    Arguments:
+        (...)
+        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            Can be set to `"both"` in order to push both safe and unsafe weights.
+    """
     # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)

@@ -236,7 +274,7 @@ def push_to_hf_hub(
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config)
+        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)

         # Add readme if it does not exist
         if not has_readme:
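Putting the pieces together, a hedged end-to-end sketch (hypothetical repo id, requires a write-enabled Hugging Face token): `push_to_hf_hub` simply forwards `safe_serialization` to `save_for_hf`, so a single call can publish both weight formats.

```python
# Sketch: push both .safetensors and .bin weights so existing .bin consumers keep working.
import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model("resnet18", pretrained=True)
push_to_hf_hub(model, "your-username/resnet18-demo", safe_serialization="both")
```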
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
     for c in citations:
         readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text
+
+
+def _get_safe_alternatives(filename: str) -> Iterable[str]:
+    """Returns potential safetensors alternatives for a given filename.
+
+    Use case:
+        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
+        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
+    """
+    if filename == HF_WEIGHTS_NAME:
+        yield HF_SAFE_WEIGHTS_NAME
+    if filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
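For reference, the naming rule the generator implements, shown as a small sketch (the outputs follow directly from the two `if` branches above; module path assumed to be `timm.models._hub`):

```python
from timm.models._hub import _get_safe_alternatives

print(list(_get_safe_alternatives("pytorch_model.bin")))
# ['model.safetensors', 'pytorch_model.safetensors']
print(list(_get_safe_alternatives("some_weights.bin")))
# ['some_weights.safetensors']
```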