Add support to load safetensors weights

pull/1680/head
testbot 1 year ago committed by Ross Wightman
parent f35d6ea57b
commit 8470e29541

@ -17,21 +17,26 @@ import os
import glob
import hashlib
from timm.models import load_state_dict
import safetensors.torch
DEFAULT_OUTPUT = "./average.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
help='path to base input folder containing checkpoints')
parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
help='checkpoint filter (path wildcard)')
parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
help='output filename')
parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='Force not using ema version of weights (if present)')
parser.add_argument('--no-sort', dest='no_sort', action='store_true',
help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
parser.add_argument('-n', type=int, default=10, metavar='N',
help='Number of checkpoints to average')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path):
@ -55,6 +60,15 @@ def main():
# by default sort by checkpoint metric (if present) and avg top n checkpoints
args.sort = not args.no_sort
if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT
if args.safetensors and not args.output.endswith(".safetensors"):
print(
"Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}"
)
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
@ -107,10 +121,13 @@ def main():
v = v.clamp(float32_info.min, float32_info.max)
final_state_dict[k] = v.to(dtype=torch.float32)
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
if args.safetensors:
safetensors.torch.save_file(final_state_dict, args.output)
else:
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(args.output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()

@ -11,8 +11,8 @@ import torch
import argparse
import os
import hashlib
import safetensors.torch
import shutil
from collections import OrderedDict
from timm.models import load_state_dict
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
@ -24,6 +24,8 @@ parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
@ -35,10 +37,10 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint))
@ -53,10 +55,13 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
if safe_serialization:
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
else:
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
else:
checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename

@ -2,3 +2,4 @@ torch>=1.7
torchvision
pyyaml
huggingface_hub
safetensors>=0.2

@ -7,6 +7,7 @@ import os
from collections import OrderedDict
import torch
import safetensors.torch
import timm.models._builder
@ -26,7 +27,12 @@ def clean_state_dict(state_dict):
def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path, map_location='cpu')
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = ''
if isinstance(checkpoint, dict):
if use_ema and checkpoint.get('state_dict_ema', None) is not None:

@ -2,19 +2,25 @@ import hashlib
import json
import logging
import os
import sys
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, Union
from typing import Iterable, Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
from timm import __version__
from timm.models._pretrained import filter_pretrained_cfg
@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
__all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']
# Default name for a weights file hosted on the Huggingface Hub.
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''):
"""
@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
return pretrained_cfg, model_name
def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
assert has_hf_hub(True)
cached_file = download_from_hf(model_id, filename)
state_dict = torch.load(cached_file, map_location='cpu')
return state_dict
hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = N
json.dump(hf_config, f, indent=2)
def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
weights_path = save_directory / 'pytorch_model.bin'
torch.save(model.state_dict(), weights_path)
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
config_path = save_directory / 'config.json'
save_config_for_hf(model, config_path, model_config=model_config)
@ -217,7 +247,15 @@ def push_to_hf_hub(
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
):
"""
Arguments:
(...)
safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
Can be set to `"both"` in order to push both safe and unsafe weights.
"""
# Create repo if it doesn't exist yet
repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@ -236,7 +274,7 @@ def push_to_hf_hub(
# Dump model and push to Hub
with TemporaryDirectory() as tmpdir:
# Save model weights and config.
save_for_hf(model, tmpdir, model_config=model_config)
save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)
# Add readme if it does not exist
if not has_readme:
@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
for c in citations:
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
Use case:
When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename.endswith(".bin"):
yield filename[:-4] + ".safetensors"
Loading…
Cancel
Save