Add support to load safetensors weights

pull/1680/head
testbot authored 2 years ago, committed by Ross Wightman
parent f35d6ea57b
commit 8470e29541

avg_checkpoints.py
@@ -17,21 +17,26 @@ import os
 import glob
 import hashlib
 from timm.models import load_state_dict
+import safetensors.torch
+DEFAULT_OUTPUT = "./average.pth"
+DEFAULT_SAFE_OUTPUT = "./average.safetensors"

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
 parser.add_argument('--input', default='', type=str, metavar='PATH',
                     help='path to base input folder containing checkpoints')
 parser.add_argument('--filter', default='*.pth.tar', type=str, metavar='WILDCARD',
                     help='checkpoint filter (path wildcard)')
-parser.add_argument('--output', default='./averaged.pth', type=str, metavar='PATH',
-                    help='output filename')
+parser.add_argument('--output', default=DEFAULT_OUTPUT, type=str, metavar='PATH',
+                    help=f'Output filename. Defaults to {DEFAULT_SAFE_OUTPUT} when passing --safetensors.')
 parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='Force not using ema version of weights (if present)')
 parser.add_argument('--no-sort', dest='no_sort', action='store_true',
                     help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
 parser.add_argument('-n', type=int, default=10, metavar='N',
                     help='Number of checkpoints to average')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')


 def checkpoint_metric(checkpoint_path):
     if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -55,6 +60,15 @@ def main():
     # by default sort by checkpoint metric (if present) and avg top n checkpoints
     args.sort = not args.no_sort

+    if args.safetensors and args.output == DEFAULT_OUTPUT:
+        # Default path changes if using safetensors
+        args.output = DEFAULT_SAFE_OUTPUT
+    if args.safetensors and not args.output.endswith(".safetensors"):
+        print(
+            "Warning: saving weights as safetensors but output file extension is not "
+            f"set to '.safetensors': {args.output}"
+        )
+
     if os.path.exists(args.output):
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
@@ -107,10 +121,13 @@ def main():
             v = v.clamp(float32_info.min, float32_info.max)
         final_state_dict[k] = v.to(dtype=torch.float32)

-    try:
-        torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
-    except:
-        torch.save(final_state_dict, args.output)
+    if args.safetensors:
+        safetensors.torch.save_file(final_state_dict, args.output)
+    else:
+        try:
+            torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
+        except:
+            torch.save(final_state_dict, args.output)

     with open(args.output, 'rb') as f:
         sha_hash = hashlib.sha256(f.read()).hexdigest()
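
Reviewer note: a minimal standalone sketch of the two save paths this hunk introduces (tensor names here are made up). `safetensors.torch.save_file` accepts only a flat `str -> torch.Tensor` mapping, which is exactly what `final_state_dict` is at this point, while `torch.save` pickles arbitrary Python objects:

```python
import torch
import safetensors.torch

# Stand-in for the averaged weights; keys and values are illustrative only.
final_state_dict = {"head.weight": torch.ones(2, 2), "head.bias": torch.zeros(2)}

# New path: plain tensor buffers on disk, no pickle involved.
safetensors.torch.save_file(final_state_dict, "average.safetensors")

# Legacy path: pickle-based, hence the need for a 'safe' alternative.
torch.save(final_state_dict, "average.pth")

# Round-trip check.
reloaded = safetensors.torch.load_file("average.safetensors", device="cpu")
assert torch.equal(reloaded["head.weight"], final_state_dict["head.weight"])
```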

clean_checkpoint.py
@@ -11,8 +11,8 @@ import torch
 import argparse
 import os
 import hashlib
+import safetensors.torch
 import shutil
-from collections import OrderedDict
 from timm.models import load_state_dict

 parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
@@ -24,6 +24,8 @@ parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
                     help='use ema version of weights if present')
 parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
                     help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
+parser.add_argument('--safetensors', action='store_true',
+                    help='Save weights using safetensors instead of the default torch way (pickle).')


 _TEMP_NAME = './_checkpoint.pth'
@@ -35,10 +37,10 @@ def main():
         print("Error: Output filename ({}) already exists.".format(args.output))
         exit(1)
-    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn)
+    clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)


-def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
+def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
     # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
     if checkpoint and os.path.isfile(checkpoint):
         print("=> Loading checkpoint '{}'".format(checkpoint))
@@ -53,10 +55,13 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(checkpoint))

-        try:
-            torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
-        except:
-            torch.save(new_state_dict, _TEMP_NAME)
+        if safe_serialization:
+            safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
+        else:
+            try:
+                torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
+            except:
+                torch.save(new_state_dict, _TEMP_NAME)

         with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
@@ -67,7 +72,7 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False):
         else:
             checkpoint_root = ''
             checkpoint_base = os.path.splitext(checkpoint)[0]
-        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
+        final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
         shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
         return final_filename
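
Note on the filename logic in the last hunk: the content hash is taken from the re-saved temp file, then spliced into the output name. A sketch, with a hypothetical input path and payload:

```python
import hashlib
import os

checkpoint = "model_best.pth.tar"                  # hypothetical input path
payload = b"..."                                   # stands in for the bytes of the re-saved temp file
sha_hash = hashlib.sha256(payload).hexdigest()
checkpoint_base = os.path.splitext(checkpoint)[0]  # strips only the last extension -> 'model_best.pth'
final_filename = "-".join([checkpoint_base, sha_hash[:8]]) + ".safetensors"
# -> 'model_best.pth-' + first 8 hex chars of the hash + '.safetensors'
```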

requirements.txt
@@ -2,3 +2,4 @@ torch>=1.7
 torchvision
 pyyaml
 huggingface_hub
+safetensors>=0.2

timm/models/_helpers.py
@@ -7,6 +7,7 @@ import os
 from collections import OrderedDict

 import torch
+import safetensors.torch

 import timm.models._builder
@@ -26,7 +27,12 @@ def clean_state_dict(state_dict):

 def load_state_dict(checkpoint_path, use_ema=True):
     if checkpoint_path and os.path.isfile(checkpoint_path):
-        checkpoint = torch.load(checkpoint_path, map_location='cpu')
+        # Check if safetensors or not and load weights accordingly
+        if str(checkpoint_path).endswith(".safetensors"):
+            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
+        else:
+            checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
         state_dict_key = ''
         if isinstance(checkpoint, dict):
             if use_ema and checkpoint.get('state_dict_ema', None) is not None:
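
With this change, callers don't need to care about the on-disk format; dispatch is purely on file extension. A usage sketch with hypothetical paths:

```python
from timm.models import load_state_dict

# Extension decides the loader:
sd = load_state_dict("model_best.safetensors")  # -> safetensors.torch.load_file(..., device='cpu')
# sd = load_state_dict("model_best.pth.tar")    # -> torch.load(..., map_location='cpu')
```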

timm/models/_hub.py
@@ -2,19 +2,25 @@ import hashlib
 import json
 import logging
 import os
+import sys
 from functools import partial
 from pathlib import Path
 from tempfile import TemporaryDirectory
-from typing import Optional, Union
+from typing import Iterable, Optional, Union

 import torch
 from torch.hub import HASH_REGEX, download_url_to_file, urlparse
+import safetensors.torch

 try:
     from torch.hub import get_dir
 except ImportError:
     from torch.hub import _get_torch_home as get_dir

+if sys.version_info >= (3, 8):
+    from typing import Literal
+else:
+    from typing_extensions import Literal

 from timm import __version__
 from timm.models._pretrained import filter_pretrained_cfg
@@ -35,6 +41,9 @@ _logger = logging.getLogger(__name__)
 __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'load_model_config_from_hf',
            'load_state_dict_from_hf', 'save_for_hf', 'push_to_hf_hub']

+# Default name for a weights file hosted on the Huggingface Hub.
+HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
+HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version

 def get_cache_dir(child_dir=''):
     """
@@ -150,11 +159,23 @@ def load_model_config_from_hf(model_id: str):
     return pretrained_cfg, model_name


-def load_state_dict_from_hf(model_id: str, filename: str = 'pytorch_model.bin'):
+def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
     assert has_hf_hub(True)
-    cached_file = download_from_hf(model_id, filename)
-    state_dict = torch.load(cached_file, map_location='cpu')
-    return state_dict
+    hf_model_id, hf_revision = hf_split(model_id)
+
+    # Look for .safetensors alternatives and load from it if it exists
+    for safe_filename in _get_safe_alternatives(filename):
+        try:
+            cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
+            _logger.warning(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
+            return safetensors.torch.load_file(cached_safe_file, device="cpu")
+        except EntryNotFoundError:
+            pass
+
+    # Otherwise, load using pytorch.load
+    cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
+    _logger.warning(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
+    return torch.load(cached_file, map_location='cpu')


 def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
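
The download order this hunk implements, sketched outside the library ('user/my-model' is a placeholder repo id; `hf_hub_download` and `EntryNotFoundError` come from `huggingface_hub`):

```python
import torch
import safetensors.torch
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError

try:
    # Preferred: a safetensors sibling of the requested weights file.
    path = hf_hub_download("user/my-model", filename="model.safetensors")
    state_dict = safetensors.torch.load_file(path, device="cpu")
except EntryNotFoundError:
    # Fallback: the legacy pickle-based weights.
    path = hf_hub_download("user/my-model", filename="pytorch_model.bin")
    state_dict = torch.load(path, map_location="cpu")
```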
@@ -195,13 +216,22 @@ def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = N
         json.dump(hf_config, f, indent=2)


-def save_for_hf(model, save_directory: str, model_config: Optional[dict] = None):
+def save_for_hf(
+        model,
+        save_directory: str,
+        model_config: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
+):
     assert has_hf_hub(True)
     save_directory = Path(save_directory)
     save_directory.mkdir(exist_ok=True, parents=True)

-    weights_path = save_directory / 'pytorch_model.bin'
-    torch.save(model.state_dict(), weights_path)
+    # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
+    tensors = model.state_dict()
+    if safe_serialization is True or safe_serialization == "both":
+        safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
+    if safe_serialization is False or safe_serialization == "both":
+        torch.save(tensors, save_directory / HF_WEIGHTS_NAME)

     config_path = save_directory / 'config.json'
     save_config_for_hf(model, config_path, model_config=model_config)
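
What lands on disk for each `safe_serialization` value; a sketch assuming a locally created timm model (the output directory is arbitrary, and `save_for_hf` still requires `huggingface_hub` to be installed since it asserts `has_hf_hub(True)`):

```python
import timm
from timm.models._hub import save_for_hf

model = timm.create_model("resnet18", pretrained=False)

save_for_hf(model, "./out", safe_serialization="both")
# ./out/model.safetensors   -- written when safe_serialization is True or "both"
# ./out/pytorch_model.bin   -- written when safe_serialization is False (default) or "both"
# ./out/config.json
```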
@@ -217,7 +247,15 @@ def push_to_hf_hub(
         create_pr: bool = False,
         model_config: Optional[dict] = None,
         model_card: Optional[dict] = None,
+        safe_serialization: Union[bool, Literal["both"]] = False
 ):
+    """
+    Arguments:
+        (...)
+        safe_serialization (`bool` or `"both"`, *optional*, defaults to `False`):
+            Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+            Can be set to `"both"` in order to push both safe and unsafe weights.
+    """
     # Create repo if it doesn't exist yet
     repo_url = create_repo(repo_id, token=token, private=private, exist_ok=True)
@@ -236,7 +274,7 @@ def push_to_hf_hub(
     # Dump model and push to Hub
     with TemporaryDirectory() as tmpdir:
         # Save model weights and config.
-        save_for_hf(model, tmpdir, model_config=model_config)
+        save_for_hf(model, tmpdir, model_config=model_config, safe_serialization=safe_serialization)

         # Add readme if it does not exist
         if not has_readme:
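
A typical call with the new argument; the repo id is a placeholder and an authenticated `huggingface_hub` session is assumed. Pushing `"both"` keeps old consumers of `pytorch_model.bin` working while also exposing safetensors weights:

```python
import timm
from timm.models._hub import push_to_hf_hub

model = timm.create_model("resnet18", pretrained=True)
push_to_hf_hub(model, "user/resnet18-demo", safe_serialization="both")
```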
@@ -302,3 +340,15 @@ def generate_readme(model_card: dict, model_name: str):
         for c in citations:
             readme_text += f"```bibtex\n{c}\n```\n"
     return readme_text
+
+
+def _get_safe_alternatives(filename: str) -> Iterable[str]:
+    """Returns potential safetensors alternatives for a given filename.
+    Use case:
+        When downloading a model from the Huggingface Hub, we first look if a .safetensors file exists and if yes, we use it.
+        Main use case is filename "pytorch_model.bin" => check for "model.safetensors" or "pytorch_model.safetensors".
+    """
+    if filename == HF_WEIGHTS_NAME:
+        yield HF_SAFE_WEIGHTS_NAME
+    if filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
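
Both branches can fire for the default name, so the generator yields two candidates for `pytorch_model.bin` (assuming the definitions above are in scope):

```python
list(_get_safe_alternatives("pytorch_model.bin"))
# -> ['model.safetensors', 'pytorch_model.safetensors']

list(_get_safe_alternatives("some_weights.bin"))
# -> ['some_weights.safetensors']

list(_get_safe_alternatives("weights.pth"))
# -> []
```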