|
|
|
@ -7,15 +7,21 @@ from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from tempfile import TemporaryDirectory
|
|
|
|
|
from typing import Iterable, Optional, Union
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
|
|
|
|
|
import safetensors.torch
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
from torch.hub import get_dir
|
|
|
|
|
except ImportError:
|
|
|
|
|
from torch.hub import _get_torch_home as get_dir
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
import safetensors.torch
|
|
|
|
|
_has_safetensors = True
|
|
|
|
|
except ImportError:
|
|
|
|
|
_has_safetensors = False
|
|
|
|
|
|
|
|
|
|
if sys.version_info >= (3, 8):
|
|
|
|
|
from typing import Literal
|
|
|
|
|
else:
|
|
|
|
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
|
|
|
|
|
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
|
|
|
|
|
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_cache_dir(child_dir=''):
|
|
|
|
|
"""
|
|
|
|
|
Returns the location of the directory where models are cached (and creates it if necessary).
|
|
|
|
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
|
|
|
|
|
hf_model_id, hf_revision = hf_split(model_id)
|
|
|
|
|
|
|
|
|
|
# Look for .safetensors alternatives and load from it if it exists
|
|
|
|
|
if _has_safetensors:
|
|
|
|
|
for safe_filename in _get_safe_alternatives(filename):
|
|
|
|
|
try:
|
|
|
|
|
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
|
|
|
|
|
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
|
|
|
|
|
_logger.info(
|
|
|
|
|
f"[{model_id}] Safe alternative available for '{filename}' "
|
|
|
|
|
f"(as '{safe_filename}'). Loading weights using safetensors.")
|
|
|
|
|
return safetensors.torch.load_file(cached_safe_file, device="cpu")
|
|
|
|
|
except EntryNotFoundError:
|
|
|
|
|
pass
|
|
|
|
|
|
|
|
|
|
# Otherwise, load using pytorch.load
|
|
|
|
|
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
|
|
|
|
|
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
|
|
|
|
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
|
|
|
|
|
return torch.load(cached_file, map_location='cpu')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
|
|
|
|
|
def save_config_for_hf(
|
|
|
|
|
model,
|
|
|
|
|
config_path: str,
|
|
|
|
|
model_config: Optional[dict] = None
|
|
|
|
|
):
|
|
|
|
|
model_config = model_config or {}
|
|
|
|
|
hf_config = {}
|
|
|
|
|
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
|
|
|
|
@ -220,8 +234,8 @@ def save_for_hf(
|
|
|
|
|
model,
|
|
|
|
|
save_directory: str,
|
|
|
|
|
model_config: Optional[dict] = None,
|
|
|
|
|
safe_serialization: Union[bool, Literal["both"]] = False
|
|
|
|
|
):
|
|
|
|
|
safe_serialization: Union[bool, Literal["both"]] = False,
|
|
|
|
|
):
|
|
|
|
|
assert has_hf_hub(True)
|
|
|
|
|
save_directory = Path(save_directory)
|
|
|
|
|
save_directory.mkdir(exist_ok=True, parents=True)
|
|
|
|
@ -229,6 +243,7 @@ def save_for_hf(
|
|
|
|
|
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
|
|
|
|
|
tensors = model.state_dict()
|
|
|
|
|
if safe_serialization is True or safe_serialization == "both":
|
|
|
|
|
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
|
|
|
|
|
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
|
|
|
|
|
if safe_serialization is False or safe_serialization == "both":
|
|
|
|
|
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
|
|
|
|
@ -247,7 +262,7 @@ def push_to_hf_hub(
|
|
|
|
|
create_pr: bool = False,
|
|
|
|
|
model_config: Optional[dict] = None,
|
|
|
|
|
model_card: Optional[dict] = None,
|
|
|
|
|
safe_serialization: Union[bool, Literal["both"]] = False
|
|
|
|
|
safe_serialization: Union[bool, Literal["both"]] = False,
|
|
|
|
|
):
|
|
|
|
|
"""
|
|
|
|
|
Arguments:
|
|
|
|
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
|
|
|
|
|
readme_text += f"```bibtex\n{c}\n```\n"
|
|
|
|
|
return readme_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
|
|
|
|
"""Returns potential safetensors alternatives for a given filename.
|
|
|
|
|
|
|
|
|
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
|
|
|
|
|
"""
|
|
|
|
|
if filename == HF_WEIGHTS_NAME:
|
|
|
|
|
yield HF_SAFE_WEIGHTS_NAME
|
|
|
|
|
if filename.endswith(".bin"):
|
|
|
|
|
yield filename[:-4] + ".safetensors"
|
|
|
|
|
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
|
|
|
|
|
return filename[:-4] + ".safetensors"
|
|
|
|
|