Make safetensor import option for now. Improve avg/clean checkpoints ext handling a bit (more consistent).

pull/1553/merge
Ross Wightman 2 years ago
parent 7d9e321b76
commit d0b45c9b4d

@ -17,10 +17,14 @@ import os
import glob import glob
import hashlib import hashlib
from timm.models import load_state_dict from timm.models import load_state_dict
try:
import safetensors.torch import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
DEFAULT_OUTPUT = "./average.pth" DEFAULT_OUTPUT = "./averaged.pth"
DEFAULT_SAFE_OUTPUT = "./average.safetensors" DEFAULT_SAFE_OUTPUT = "./averaged.safetensors"
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
parser.add_argument('--input', default='', type=str, metavar='PATH', parser.add_argument('--input', default='', type=str, metavar='PATH',
@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N',
parser.add_argument('--safetensors', action='store_true', parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).') help='Save weights using safetensors instead of the default torch way (pickle).')
def checkpoint_metric(checkpoint_path): def checkpoint_metric(checkpoint_path):
if not checkpoint_path or not os.path.isfile(checkpoint_path): if not checkpoint_path or not os.path.isfile(checkpoint_path):
return {} return {}
@ -63,14 +68,20 @@ def main():
if args.safetensors and args.output == DEFAULT_OUTPUT: if args.safetensors and args.output == DEFAULT_OUTPUT:
# Default path changes if using safetensors # Default path changes if using safetensors
args.output = DEFAULT_SAFE_OUTPUT args.output = DEFAULT_SAFE_OUTPUT
if args.safetensors and not args.output.endswith(".safetensors"):
output, output_ext = os.path.splitext(args.output)
if not output_ext:
output_ext = ('.safetensors' if args.safetensors else '.pth')
output = output + output_ext
if args.safetensors and not output_ext == ".safetensors":
print( print(
"Warning: saving weights as safetensors but output file extension is not " "Warning: saving weights as safetensors but output file extension is not "
f"set to '.safetensors': {args.output}" f"set to '.safetensors': {args.output}"
) )
if os.path.exists(args.output): if os.path.exists(output):
print("Error: Output filename ({}) already exists.".format(args.output)) print("Error: Output filename ({}) already exists.".format(output))
exit(1) exit(1)
pattern = args.input pattern = args.input
@ -87,22 +98,27 @@ def main():
checkpoint_metrics.append((metric, c)) checkpoint_metrics.append((metric, c))
checkpoint_metrics = list(sorted(checkpoint_metrics)) checkpoint_metrics = list(sorted(checkpoint_metrics))
checkpoint_metrics = checkpoint_metrics[-args.n:] checkpoint_metrics = checkpoint_metrics[-args.n:]
if checkpoint_metrics:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(m, c) for m, c in checkpoint_metrics] [print(m, c) for m, c in checkpoint_metrics]
avg_checkpoints = [c for m, c in checkpoint_metrics] avg_checkpoints = [c for m, c in checkpoint_metrics]
else: else:
avg_checkpoints = checkpoints avg_checkpoints = checkpoints
if avg_checkpoints:
print("Selected checkpoints:") print("Selected checkpoints:")
[print(c) for c in checkpoints] [print(c) for c in checkpoints]
if not avg_checkpoints:
print('Error: No checkpoints found to average.')
exit(1)
avg_state_dict = {} avg_state_dict = {}
avg_counts = {} avg_counts = {}
for c in avg_checkpoints: for c in avg_checkpoints:
new_state_dict = load_state_dict(c, args.use_ema) new_state_dict = load_state_dict(c, args.use_ema)
if not new_state_dict: if not new_state_dict:
print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) print(f"Error: Checkpoint ({c}) doesn't exist")
continue continue
for k, v in new_state_dict.items(): for k, v in new_state_dict.items():
if k not in avg_state_dict: if k not in avg_state_dict:
avg_state_dict[k] = v.clone().to(dtype=torch.float64) avg_state_dict[k] = v.clone().to(dtype=torch.float64)
@ -122,16 +138,14 @@ def main():
final_state_dict[k] = v.to(dtype=torch.float32) final_state_dict[k] = v.to(dtype=torch.float32)
if args.safetensors: if args.safetensors:
safetensors.torch.save_file(final_state_dict, args.output) assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(final_state_dict, output)
else: else:
try: torch.save(final_state_dict, output)
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(args.output, 'rb') as f: with open(output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest() sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'")
if __name__ == '__main__': if __name__ == '__main__':

@ -11,9 +11,14 @@ import torch
import argparse import argparse
import os import os
import hashlib import hashlib
import safetensors.torch
import shutil import shutil
import tempfile
from timm.models import load_state_dict from timm.models import load_state_dict
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner')
parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH',
help='output path') help='output path')
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true',
help='use ema version of weights if present') help='use ema version of weights if present')
parser.add_argument('--no-hash', dest='no_hash', action='store_true',
help='no hash in output filename')
parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true',
help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint')
parser.add_argument('--safetensors', action='store_true', parser.add_argument('--safetensors', action='store_true',
help='Save weights using safetensors instead of the default torch way (pickle).') help='Save weights using safetensors instead of the default torch way (pickle).')
_TEMP_NAME = './_checkpoint.pth'
def main(): def main():
args = parser.parse_args() args = parser.parse_args()
@ -37,10 +42,24 @@ def main():
print("Error: Output filename ({}) already exists.".format(args.output)) print("Error: Output filename ({}) already exists.".format(args.output))
exit(1) exit(1)
clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors) clean_checkpoint(
args.checkpoint,
args.output,
not args.no_use_ema,
args.no_hash,
args.clean_aux_bn,
safe_serialization=args.safetensors,
)
def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, safe_serialization: bool=False): def clean_checkpoint(
checkpoint,
output,
use_ema=True,
no_hash=False,
clean_aux_bn=False,
safe_serialization: bool=False,
):
# Load an existing checkpoint to CPU, strip everything but the state_dict and re-save # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save
if checkpoint and os.path.isfile(checkpoint): if checkpoint and os.path.isfile(checkpoint):
print("=> Loading checkpoint '{}'".format(checkpoint)) print("=> Loading checkpoint '{}'".format(checkpoint))
@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa
new_state_dict[name] = v new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(checkpoint)) print("=> Loaded state_dict from '{}'".format(checkpoint))
ext = ''
if output:
checkpoint_root, checkpoint_base = os.path.split(output)
checkpoint_base, ext = os.path.splitext(checkpoint_base)
else:
checkpoint_root = ''
checkpoint_base = os.path.split(checkpoint)[1]
checkpoint_base = os.path.splitext(checkpoint_base)[0]
temp_filename = '__' + checkpoint_base
if safe_serialization: if safe_serialization:
safetensors.torch.save_file(new_state_dict, _TEMP_NAME) assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(new_state_dict, temp_filename)
else: else:
try: torch.save(new_state_dict, temp_filename)
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f: with open(temp_filename, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest() sha_hash = hashlib.sha256(f.read()).hexdigest()
if output: if ext:
checkpoint_root, checkpoint_base = os.path.split(output) final_ext = ext
checkpoint_base = os.path.splitext(checkpoint_base)[0]
else: else:
checkpoint_root = '' final_ext = ('.safetensors' if safe_serialization else '.pth')
checkpoint_base = os.path.splitext(checkpoint)[0]
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth') if no_hash:
shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) final_filename = checkpoint_base + final_ext
else:
final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + final_ext
shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename))
print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
return final_filename return final_filename
else: else:

@ -7,7 +7,11 @@ import os
from collections import OrderedDict from collections import OrderedDict
import torch import torch
try:
import safetensors.torch import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
import timm.models._builder import timm.models._builder
@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
# Check if safetensors or not and load weights accordingly # Check if safetensors or not and load weights accordingly
if str(checkpoint_path).endswith(".safetensors"): if str(checkpoint_path).endswith(".safetensors"):
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu') checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
else: else:
checkpoint = torch.load(checkpoint_path, map_location='cpu') checkpoint = torch.load(checkpoint_path, map_location='cpu')

@ -7,15 +7,21 @@ from functools import partial
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Iterable, Optional, Union from typing import Iterable, Optional, Union
import torch import torch
from torch.hub import HASH_REGEX, download_url_to_file, urlparse from torch.hub import HASH_REGEX, download_url_to_file, urlparse
import safetensors.torch
try: try:
from torch.hub import get_dir from torch.hub import get_dir
except ImportError: except ImportError:
from torch.hub import _get_torch_home as get_dir from torch.hub import _get_torch_home as get_dir
try:
import safetensors.torch
_has_safetensors = True
except ImportError:
_has_safetensors = False
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
from typing import Literal from typing import Literal
else: else:
@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl
HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version
def get_cache_dir(child_dir=''): def get_cache_dir(child_dir=''):
""" """
Returns the location of the directory where models are cached (and creates it if necessary). Returns the location of the directory where models are cached (and creates it if necessary).
@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME):
hf_model_id, hf_revision = hf_split(model_id) hf_model_id, hf_revision = hf_split(model_id)
# Look for .safetensors alternatives and load from it if it exists # Look for .safetensors alternatives and load from it if it exists
if _has_safetensors:
for safe_filename in _get_safe_alternatives(filename): for safe_filename in _get_safe_alternatives(filename):
try: try:
cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision) cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.") _logger.info(
f"[{model_id}] Safe alternative available for '{filename}' "
f"(as '{safe_filename}'). Loading weights using safetensors.")
return safetensors.torch.load_file(cached_safe_file, device="cpu") return safetensors.torch.load_file(cached_safe_file, device="cpu")
except EntryNotFoundError: except EntryNotFoundError:
pass pass
# Otherwise, load using pytorch.load # Otherwise, load using pytorch.load
cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision)
_logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.") _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.")
return torch.load(cached_file, map_location='cpu') return torch.load(cached_file, map_location='cpu')
def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None): def save_config_for_hf(
model,
config_path: str,
model_config: Optional[dict] = None
):
model_config = model_config or {} model_config = model_config or {}
hf_config = {} hf_config = {}
pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True)
@ -220,7 +234,7 @@ def save_for_hf(
model, model,
save_directory: str, save_directory: str,
model_config: Optional[dict] = None, model_config: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False safe_serialization: Union[bool, Literal["both"]] = False,
): ):
assert has_hf_hub(True) assert has_hf_hub(True)
save_directory = Path(save_directory) save_directory = Path(save_directory)
@ -229,6 +243,7 @@ def save_for_hf(
# Save model weights, either safely (using safetensors), or using legacy pytorch approach or both. # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both.
tensors = model.state_dict() tensors = model.state_dict()
if safe_serialization is True or safe_serialization == "both": if safe_serialization is True or safe_serialization == "both":
assert _has_safetensors, "`pip install safetensors` to use .safetensors"
safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME)
if safe_serialization is False or safe_serialization == "both": if safe_serialization is False or safe_serialization == "both":
torch.save(tensors, save_directory / HF_WEIGHTS_NAME) torch.save(tensors, save_directory / HF_WEIGHTS_NAME)
@ -247,7 +262,7 @@ def push_to_hf_hub(
create_pr: bool = False, create_pr: bool = False,
model_config: Optional[dict] = None, model_config: Optional[dict] = None,
model_card: Optional[dict] = None, model_card: Optional[dict] = None,
safe_serialization: Union[bool, Literal["both"]] = False safe_serialization: Union[bool, Literal["both"]] = False,
): ):
""" """
Arguments: Arguments:
@ -341,6 +356,7 @@ def generate_readme(model_card: dict, model_name: str):
readme_text += f"```bibtex\n{c}\n```\n" readme_text += f"```bibtex\n{c}\n```\n"
return readme_text return readme_text
def _get_safe_alternatives(filename: str) -> Iterable[str]: def _get_safe_alternatives(filename: str) -> Iterable[str]:
"""Returns potential safetensors alternatives for a given filename. """Returns potential safetensors alternatives for a given filename.
@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
""" """
if filename == HF_WEIGHTS_NAME: if filename == HF_WEIGHTS_NAME:
yield HF_SAFE_WEIGHTS_NAME yield HF_SAFE_WEIGHTS_NAME
if filename.endswith(".bin"): if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
yield filename[:-4] + ".safetensors" return filename[:-4] + ".safetensors"

Loading…
Cancel
Save