Make safetensor import option for now. Improve avg/clean checkpoints ext handling a bit (more consistent).

pull/1553/merge
Ross Wightman 2 years ago
parent 7d9e321b76
commit d0b45c9b4d

@ -17,10 +17,14 @@ import os
import glob
import hashlib
from timm.models import load_state_dict
import safetensors.torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./average.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors"
DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH',
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {}
@ -63,14 +68,20 @@ def main():
if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT
if args.safetensors and not args.output.endswith(".safetensors"):
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
print(
"Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}"
)
if os.path.exists(args.output):
print("Error: Output filename ({}) already exists.".format(args.output))
if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(output))
exit(1)
pattern = args.input
@ -87,22 +98,27 @@ def main():
checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:]
print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics]
if checkpoint_metrics:
print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics]
else:
avg_checkpoints = checkpoints
print("Selected checkpoints:")
[print(c) for c in checkpoints]
if avg_checkpoints:
print("Selected checkpoints:")
[print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {}
avg_counts = {}
for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict:
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))
print(f"Error: Checkpoint ({c}) doesn't exist")
continue
for k, v in new_state_dict.items():
if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@ -122,16 +138,14 @@ def main():
final_state_dict[k] = v.to(dtype=torch.float32)
if args.safetensors:
safetensors.torch.save_file(final_state_dict, args.output)
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(final_state_dict, output)
else:
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
torch.save(final_state_dict, output)
with open(args.output, 'rb') as f:
with open(output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))
print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
if __name__ == '__main__':

@ -11,9 +11,14 @@ import torch
import argparse
import os
import hashlib
import safetensors.torch
import shutil
import tempfile
from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
def main():
args = parser.parse_args()
@ -37,10 +42,24 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output))
exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors)
clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False):
def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint))
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint))
ext = ''
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base, ext = os.path.splitext(checkpoint_base)
else:
checkpoint_root = ''
checkpoint_base = os.path.split(checkpoint)[1]
checkpoint_base = os.path.splitext(checkpoint_base)[0]
temp_filename = '__' + checkpoint_base
if safe_serialization:
safetensors.torch.save_file(new_state_dict, _TEMP_NAME)
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
else:
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
torch.save(new_state_dict, temp_filename)
with open(_TEMP_NAME, 'rb') as f:
with open(temp_filename, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base = os.path.splitext(checkpoint_base)[0]
if ext:
final_ext = ext
else:
checkpoint_root = ''
checkpoint_base = os.path.splitext(checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth')
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
final_ext = ('.safetensors' if safe_serialization else '.pth')
if no_hash:
final_filename = checkpoint_base + final_ext
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename
else:

@ -7,7 +7,11 @@ import os
from collections import OrderedDict
import torch
import safetensors.torch
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else:
checkpoint = torch.load(checkpoint_path, map_location='cpu')

@ -7,15 +7,21 @@ from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union
import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try:
from torch.hub import get_dir
except ImportError:
from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8):
from typing import Literal
else:
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''):
"""
Returns the location of the directory where models are cached (and creates it if necessary).
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename):
try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError:
pass
# Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
_logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None):
def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
model_config = model_config or {}
hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@ -220,8 +234,8 @@ def save_for_hf(
model,
save_directory: str,
model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
):
safe_serialization: Union[bool, Literal["both"]] = False,
):
assert has_hf_hub(True)
save_directory = Path(save_directory)
save_directory.mkdir(exist_ok=True, parents=True)
@ -229,6 +243,7 @@ def save_for_hf(
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@ -238,16 +253,16 @@ def save_for_hf(
def push_to_hf_hub(
model,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False
model,
repo_id: str,
commit_message: str = 'Add model',
token: Optional[str] = None,
revision: Optional[str] = None,
private: bool = False,
create_pr: bool = False,
model_config: Optional[dict] = None,
model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False,
):
"""
Arguments:
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
readme_text += f"```bibtex\n{c}\n```\n"
return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename.
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""
if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME
if filename.endswith(".bin"):
yield filename[:-4] + ".safetensors"
if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
return filename[:-4] + ".safetensors"

Loading…
Cancel
Save