From d0b45c9b4d6c38e5c93ede6e9b8fc676bcb9818e Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 18 Feb 2023 16:06:42 -0800 Subject: [PATCH] Make safetensor import optional for now. Improve avg/clean checkpoints ext handling a bit (more consistent). --- avg_checkpoints.py | 52 ++++++++++++++++++++------------ clean_checkpoint.py | 66 ++++++++++++++++++++++++++++++----------- timm/models/_helpers.py | 7 ++++- timm/models/_hub.py | 64 ++++++++++++++++++++++++--------------- 4 files changed, 127 insertions(+), 62 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index bdfa2265..6cedcb7d 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -17,10 +17,14 @@ import os import glob import hashlib from timm.models import load_state_dict -import safetensors.torch +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False -DEFAULT_OUTPUT = "./average.pth" -DEFAULT_SAFE_OUTPUT = "./average.safetensors" +DEFAULT_OUTPUT = "./averaged.pth" +DEFAULT_SAFE_OUTPUT = "./averaged.safetensors" parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager') parser.add_argument('--input', default='', type=str, metavar='PATH', @@ -38,6 +42,7 @@ parser.add_argument('-n', type=int, default=10, metavar='N', parser.add_argument('--safetensors', action='store_true', help='Save weights using safetensors instead of the default torch way (pickle).') + def checkpoint_metric(checkpoint_path): if not checkpoint_path or not os.path.isfile(checkpoint_path): return {} @@ -63,14 +68,20 @@ def main(): if args.safetensors and args.output == DEFAULT_OUTPUT: # Default path changes if using safetensors args.output = DEFAULT_SAFE_OUTPUT - if args.safetensors and not args.output.endswith(".safetensors"): + + output, output_ext = os.path.splitext(args.output) + if not output_ext: + output_ext = ('.safetensors' if args.safetensors else '.pth') + output = output + output_ext + + if args.safetensors and not output_ext ==
".safetensors": print( "Warning: saving weights as safetensors but output file extension is not " f"set to '.safetensors': {args.output}" ) - if os.path.exists(args.output): - print("Error: Output filename ({}) already exists.".format(args.output)) + if os.path.exists(output): + print("Error: Output filename ({}) already exists.".format(output)) exit(1) pattern = args.input @@ -87,22 +98,27 @@ def main(): checkpoint_metrics.append((metric, c)) checkpoint_metrics = list(sorted(checkpoint_metrics)) checkpoint_metrics = checkpoint_metrics[-args.n:] - print("Selected checkpoints:") - [print(m, c) for m, c in checkpoint_metrics] + if checkpoint_metrics: + print("Selected checkpoints:") + [print(m, c) for m, c in checkpoint_metrics] avg_checkpoints = [c for m, c in checkpoint_metrics] else: avg_checkpoints = checkpoints - print("Selected checkpoints:") - [print(c) for c in checkpoints] + if avg_checkpoints: + print("Selected checkpoints:") + [print(c) for c in checkpoints] + + if not avg_checkpoints: + print('Error: No checkpoints found to average.') + exit(1) avg_state_dict = {} avg_counts = {} for c in avg_checkpoints: new_state_dict = load_state_dict(c, args.use_ema) if not new_state_dict: - print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint)) + print(f"Error: Checkpoint ({c}) doesn't exist") continue - for k, v in new_state_dict.items(): if k not in avg_state_dict: avg_state_dict[k] = v.clone().to(dtype=torch.float64) @@ -122,16 +138,14 @@ def main(): final_state_dict[k] = v.to(dtype=torch.float32) if args.safetensors: - safetensors.torch.save_file(final_state_dict, args.output) + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(final_state_dict, output) else: - try: - torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False) - except: - torch.save(final_state_dict, args.output) + torch.save(final_state_dict, output) - with open(args.output, 'rb') as f: + with open(output, 
'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() - print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) + print(f"=> Saved state_dict to '{output}, SHA256: {sha_hash}'") if __name__ == '__main__': diff --git a/clean_checkpoint.py b/clean_checkpoint.py index d18951bc..c2da0642 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -11,9 +11,14 @@ import torch import argparse import os import hashlib -import safetensors.torch import shutil +import tempfile from timm.models import load_state_dict +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False parser = argparse.ArgumentParser(description='PyTorch Checkpoint Cleaner') parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', @@ -22,13 +27,13 @@ parser.add_argument('--output', default='', type=str, metavar='PATH', help='output path') parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', help='use ema version of weights if present') +parser.add_argument('--no-hash', dest='no_hash', action='store_true', + help='no hash in output filename') parser.add_argument('--clean-aux-bn', dest='clean_aux_bn', action='store_true', help='remove auxiliary batch norm layers (from SplitBN training) from checkpoint') parser.add_argument('--safetensors', action='store_true', help='Save weights using safetensors instead of the default torch way (pickle).') -_TEMP_NAME = './_checkpoint.pth' - def main(): args = parser.parse_args() @@ -37,10 +42,24 @@ def main(): print("Error: Output filename ({}) already exists.".format(args.output)) exit(1) - clean_checkpoint(args.checkpoint, args.output, not args.no_use_ema, args.clean_aux_bn, safe_serialization=args.safetensors) + clean_checkpoint( + args.checkpoint, + args.output, + not args.no_use_ema, + args.no_hash, + args.clean_aux_bn, + safe_serialization=args.safetensors, + ) -def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, 
safe_serialization: bool=False): +def clean_checkpoint( + checkpoint, + output, + use_ema=True, + no_hash=False, + clean_aux_bn=False, + safe_serialization: bool=False, +): # Load an existing checkpoint to CPU, strip everything but the state_dict and re-save if checkpoint and os.path.isfile(checkpoint): print("=> Loading checkpoint '{}'".format(checkpoint)) @@ -55,25 +74,36 @@ def clean_checkpoint(checkpoint, output='', use_ema=True, clean_aux_bn=False, sa new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(checkpoint)) + ext = '' + if output: + checkpoint_root, checkpoint_base = os.path.split(output) + checkpoint_base, ext = os.path.splitext(checkpoint_base) + else: + checkpoint_root = '' + checkpoint_base = os.path.split(checkpoint)[1] + checkpoint_base = os.path.splitext(checkpoint_base)[0] + + temp_filename = '__' + checkpoint_base if safe_serialization: - safetensors.torch.save_file(new_state_dict, _TEMP_NAME) + assert _has_safetensors, "`pip install safetensors` to use .safetensors" + safetensors.torch.save_file(new_state_dict, temp_filename) else: - try: - torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) - except: - torch.save(new_state_dict, _TEMP_NAME) + torch.save(new_state_dict, temp_filename) - with open(_TEMP_NAME, 'rb') as f: + with open(temp_filename, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() - if output: - checkpoint_root, checkpoint_base = os.path.split(output) - checkpoint_base = os.path.splitext(checkpoint_base)[0] + if ext: + final_ext = ext else: - checkpoint_root = '' - checkpoint_base = os.path.splitext(checkpoint)[0] - final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + ('.safetensors' if safe_serialization else '.pth') - shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename)) + final_ext = ('.safetensors' if safe_serialization else '.pth') + + if no_hash: + final_filename = checkpoint_base + final_ext + else: + final_filename = 
'-'.join([checkpoint_base, sha_hash[:8]]) + final_ext + + shutil.move(temp_filename, os.path.join(checkpoint_root, final_filename)) print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash)) return final_filename else: diff --git a/timm/models/_helpers.py b/timm/models/_helpers.py index adae77eb..89ce4318 100644 --- a/timm/models/_helpers.py +++ b/timm/models/_helpers.py @@ -7,7 +7,11 @@ import os from collections import OrderedDict import torch -import safetensors.torch +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False import timm.models._builder @@ -29,6 +33,7 @@ def load_state_dict(checkpoint_path, use_ema=True): if checkpoint_path and os.path.isfile(checkpoint_path): # Check if safetensors or not and load weights accordingly if str(checkpoint_path).endswith(".safetensors"): + assert _has_safetensors, "`pip install safetensors` to use .safetensors" checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu') else: checkpoint = torch.load(checkpoint_path, map_location='cpu') diff --git a/timm/models/_hub.py b/timm/models/_hub.py index 464d6b06..270f8568 100644 --- a/timm/models/_hub.py +++ b/timm/models/_hub.py @@ -7,15 +7,21 @@ from functools import partial from pathlib import Path from tempfile import TemporaryDirectory from typing import Iterable, Optional, Union + import torch from torch.hub import HASH_REGEX, download_url_to_file, urlparse -import safetensors.torch try: from torch.hub import get_dir except ImportError: from torch.hub import _get_torch_home as get_dir +try: + import safetensors.torch + _has_safetensors = True +except ImportError: + _has_safetensors = False + if sys.version_info >= (3, 8): from typing import Literal else: @@ -45,6 +51,7 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l HF_WEIGHTS_NAME = "pytorch_model.bin" # default pytorch pkl HF_SAFE_WEIGHTS_NAME = "model.safetensors" # safetensors version + def 
get_cache_dir(child_dir=''): """ Returns the location of the directory where models are cached (and creates it if necessary). @@ -164,21 +171,28 @@ def load_state_dict_from_hf(model_id: str, filename: str = HF_WEIGHTS_NAME): hf_model_id, hf_revision = hf_split(model_id) # Look for .safetensors alternatives and load from it if it exists - for safe_filename in _get_safe_alternatives(filename): - try: - cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision) - _logger.info(f"[{model_id}] Safe alternative available for '{filename}' (as '{safe_filename}'). Loading weights using safetensors.") - return safetensors.torch.load_file(cached_safe_file, device="cpu") - except EntryNotFoundError: - pass + if _has_safetensors: + for safe_filename in _get_safe_alternatives(filename): + try: + cached_safe_file = hf_hub_download(repo_id=hf_model_id, filename=safe_filename, revision=hf_revision) + _logger.info( + f"[{model_id}] Safe alternative available for '{filename}' " + f"(as '{safe_filename}'). Loading weights using safetensors.") + return safetensors.torch.load_file(cached_safe_file, device="cpu") + except EntryNotFoundError: + pass # Otherwise, load using pytorch.load cached_file = hf_hub_download(hf_model_id, filename=filename, revision=hf_revision) - _logger.info(f"[{model_id}] Safe alternative not found for '{filename}'. Loading weights using default pytorch.") + _logger.debug(f"[{model_id}] Safe alternative not found for '{filename}'. 
Loading weights using default pytorch.") return torch.load(cached_file, map_location='cpu') -def save_config_for_hf(model, config_path: str, model_config: Optional[dict] = None): +def save_config_for_hf( + model, + config_path: str, + model_config: Optional[dict] = None +): model_config = model_config or {} hf_config = {} pretrained_cfg = filter_pretrained_cfg(model.pretrained_cfg, remove_source=True, remove_null=True) @@ -220,8 +234,8 @@ def save_for_hf( model, save_directory: str, model_config: Optional[dict] = None, - safe_serialization: Union[bool, Literal["both"]] = False - ): + safe_serialization: Union[bool, Literal["both"]] = False, +): assert has_hf_hub(True) save_directory = Path(save_directory) save_directory.mkdir(exist_ok=True, parents=True) @@ -229,6 +243,7 @@ def save_for_hf( # Save model weights, either safely (using safetensors), or using legacy pytorch approach or both. tensors = model.state_dict() if safe_serialization is True or safe_serialization == "both": + assert _has_safetensors, "`pip install safetensors` to use .safetensors" safetensors.torch.save_file(tensors, save_directory / HF_SAFE_WEIGHTS_NAME) if safe_serialization is False or safe_serialization == "both": torch.save(tensors, save_directory / HF_WEIGHTS_NAME) @@ -238,16 +253,16 @@ def save_for_hf( def push_to_hf_hub( - model, - repo_id: str, - commit_message: str = 'Add model', - token: Optional[str] = None, - revision: Optional[str] = None, - private: bool = False, - create_pr: bool = False, - model_config: Optional[dict] = None, - model_card: Optional[dict] = None, - safe_serialization: Union[bool, Literal["both"]] = False + model, + repo_id: str, + commit_message: str = 'Add model', + token: Optional[str] = None, + revision: Optional[str] = None, + private: bool = False, + create_pr: bool = False, + model_config: Optional[dict] = None, + model_card: Optional[dict] = None, + safe_serialization: Union[bool, Literal["both"]] = False, ): """ Arguments: @@ -341,6 +356,7 @@ def 
generate_readme(model_card: dict, model_name: str): readme_text += f"```bibtex\n{c}\n```\n" return readme_text + def _get_safe_alternatives(filename: str) -> Iterable[str]: """Returns potential safetensors alternatives for a given filename. @@ -350,5 +366,5 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]: """ if filename == HF_WEIGHTS_NAME: yield HF_SAFE_WEIGHTS_NAME - if filename.endswith(".bin"): - yield filename[:-4] + ".safetensors" \ No newline at end of file + if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"): + yield filename[:-4] + ".safetensors"