Default to old checkpoint format for now, still want compatibility with older torch ver for released models

pull/250/head
Ross Wightman 4 years ago
parent a4d8fea61e
commit 9305313291

@ -103,7 +103,11 @@ def main():
v = v.clamp(float32_info.min, float32_info.max)
final_state_dict[k] = v.to(dtype=torch.float32)
torch.save(final_state_dict, args.output)
try:
torch.save(final_state_dict, args.output, _use_new_zipfile_serialization=False)
except:
torch.save(final_state_dict, args.output)
with open(args.output, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()
print("=> Saved state_dict to '{}, SHA256: {}'".format(args.output, sha_hash))

@ -57,7 +57,11 @@ def main():
new_state_dict[name] = v
print("=> Loaded state_dict from '{}'".format(args.checkpoint))
torch.save(new_state_dict, _TEMP_NAME)
try:
torch.save(new_state_dict, _TEMP_NAME, _use_new_zipfile_serialization=False)
except:
torch.save(new_state_dict, _TEMP_NAME)
with open(_TEMP_NAME, 'rb') as f:
sha_hash = hashlib.sha256(f.read()).hexdigest()

Loading…
Cancel
Save