@@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
         raise FileNotFoundError()
 
 
-def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
+def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):
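A minimal usage sketch of the new opt-in flag (the import path and checkpoint file below are assumptions for illustration, not taken from this diff):

    import torch.nn as nn
    from helpers import load_checkpoint  # hypothetical import path for the module patched here

    model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 2))
    # remap=True tolerates renamed state_dict keys, provided parameter
    # registration order and tensor sizes still match the model
    load_checkpoint(model, 'renamed_checkpoint.pth', remap=True)

Since remap defaults to False, existing callers keep the current behavior unchanged.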
@@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
     state_dict = load_state_dict(checkpoint_path, use_ema)
+    if remap:
+        state_dict = remap_checkpoint(model, state_dict)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
 
 
+def remap_checkpoint(model, state_dict, allow_reshape=True):
+    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
+    This assumes models (and originating state dict) were created with params registered in same order.
+    """
+    out_dict = {}
+    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
+        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+        if va.shape != vb.shape:
+            if allow_reshape:
+                vb = vb.reshape(va.shape)
+            else:
+                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+        out_dict[ka] = vb
+    return out_dict
+
+
 def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
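To see what the new remap path does end to end, here is a self-contained sketch (assuming `remap_checkpoint` from the hunk above is in scope; the module and tensors are toy data):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(3, 2))   # keys: '0.weight' [2, 3], '0.bias' [2]
    src = {
        'fc.w': torch.randn(3, 2),           # same numel as '0.weight', different shape
        'fc.b': torch.zeros(2),              # matches '0.bias' exactly
    }

    remapped = remap_checkpoint(model, src)
    print(list(remapped.keys()))             # ['0.weight', '0.bias']
    print(remapped['0.weight'].shape)        # torch.Size([2, 3]) after reshape
    model.load_state_dict(remapped)          # now loads cleanly

With allow_reshape=False, the mismatched 'fc.w' would instead trip the shape-mismatch assertion, the safer choice when a silent reshape could hide a real layout difference.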