From 9305313291ab1966b093abde83d78e1e7e15186d Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Oct 2020 12:58:04 -0700 Subject: [PATCH] Default to old checkpoint format for now, still want compatibility with older torch ver for released models --- avg_checkpoints.py | 6 +++++- clean_checkpoint.py | 6 +++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/avg_checkpoints.py b/avg_checkpoints.py index feeac8af..a6921224 100755 --- a/avg_checkpoints.py +++ b/avg_checkpoints.py @@ -103,7 +103,11 @@ def main(): v = v.clamp(float32_info.min, float32_info.max) final_state_dict[k] = v.to(dtype=torch.float32) - torch.save(final_state_dict, args.output) + try: + torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False) + except: + torch.save(final_state_dict, args.output) + with open(args.output, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest() print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash)) diff --git a/clean_checkpoint.py b/clean_checkpoint.py index af67f3b9..94f184d1 100755 --- a/clean_checkpoint.py +++ b/clean_checkpoint.py @@ -57,7 +57,11 @@ def main(): new_state_dict[name] = v print("=> Loaded state_dict from '{}'".format(args.checkpoint)) - torch.save(new_state_dict, _TEMP_NAME) + try: + torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False) + except: + torch.save(new_state_dict, _TEMP_NAME) + with open(_TEMP_NAME, 'rb') as f: sha_hash = hashlib.sha256(f.read()).hexdigest()