Add brute-force checkpoint remapping option

pull/1479/head
Ross Wightman 2 years ago
parent b293dfa595
commit e858912e0c

@@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True):
         raise FileNotFoundError()
 
 
-def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
+def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
     if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
         # numpy checkpoint, try to load via model specific load_pretrained fn
         if hasattr(model, 'load_pretrained'):
@@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True):
             raise NotImplementedError('Model cannot load numpy checkpoint')
         return
     state_dict = load_state_dict(checkpoint_path, use_ema)
+    if remap:
+        state_dict = remap_checkpoint(model, state_dict)
     incompatible_keys = model.load_state_dict(state_dict, strict=strict)
     return incompatible_keys
 
 
+def remap_checkpoint(model, state_dict, allow_reshape=True):
+    """ remap checkpoint by iterating over state dicts in order (ignoring original keys).
+    This assumes models (and originating state dict) were created with params registered in same order.
+    """
+    out_dict = {}
+    for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
+        assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+        if va.shape != vb.shape:
+            if allow_reshape:
+                vb = vb.reshape(va.shape)
+            else:
+                assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
+        out_dict[ka] = vb
+    return out_dict
+
+
 def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
     resume_epoch = None
     if os.path.isfile(checkpoint_path):
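
For context, a minimal usage sketch of the new option (not part of the commit itself): the model name and checkpoint filename below are hypothetical, and it assumes load_checkpoint is importable from timm.models as in recent timm releases.

    import timm
    from timm.models import load_checkpoint

    # Hypothetical example: the checkpoint was saved from a model whose parameter names
    # differ from the current definition, but whose parameters were registered in the
    # same order (the assumption remap_checkpoint relies on).
    model = timm.create_model('resnet50', num_classes=10)

    # remap=True ignores the checkpoint's keys, matches tensors purely by order,
    # and reshapes them where the element counts agree.
    load_checkpoint(model, 'renamed_layers_checkpoint.pth.tar', remap=True)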
