From e858912e0c20f22b28af5c4d9e04b1872b76f86f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 23 Sep 2022 16:07:03 -0700 Subject: [PATCH] Add brute-force checkpoint remapping option --- timm/models/helpers.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index d68c7e65..c771e825 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -63,7 +63,7 @@ def load_state_dict(checkpoint_path, use_ema=True): raise FileNotFoundError() -def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): +def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False): if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'): # numpy checkpoint, try to load via model specific load_pretrained fn if hasattr(model, 'load_pretrained'): @@ -72,10 +72,28 @@ def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True): raise NotImplementedError('Model cannot load numpy checkpoint') return state_dict = load_state_dict(checkpoint_path, use_ema) + if remap: + state_dict = remap_checkpoint(model, state_dict) incompatible_keys = model.load_state_dict(state_dict, strict=strict) return incompatible_keys +def remap_checkpoint(model, state_dict, allow_reshape=True): + """ remap checkpoint by iterating over state dicts in order (ignoring original keys). + This assumes models (and originating state dict) were created with params registered in same order. + """ + out_dict = {} + for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()): + assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + if va.shape != vb.shape: + if allow_reshape: + vb = vb.reshape(va.shape) + else: + assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.' + out_dict[ka] = vb + return out_dict + + def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True): resume_epoch = None if os.path.isfile(checkpoint_path):