""" Model creation / weight loading / state_dict helpers
|
|
|
|
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
|
|
"""
|
|
|
|
import logging
|
|
|
|
import os
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import safetensors.torch
|
|
|
|
|
|
|
|
import timm.models._builder
|
|
|
|
|
|
|
|
_logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_checkpoint', 'resume_checkpoint']


def clean_state_dict(state_dict):
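    """ Return a copy of a state dict with any 'module.' prefix (as left behind by
    DataParallel / DistributedDataParallel wrappers) stripped from its keys, preserving key order.
    """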
    # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
    cleaned_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:] if k.startswith('module.') else k
        cleaned_state_dict[name] = v
    return cleaned_state_dict


def load_state_dict(checkpoint_path, use_ema=True):
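    """ Load a state dict from a checkpoint file and clean it for use.

    Supports both .safetensors files and regular torch checkpoints. For dict-style
    checkpoints, EMA weights ('state_dict_ema' / 'model_ema') are preferred when
    `use_ema` is set, then 'state_dict' / 'model'; otherwise the checkpoint itself is
    treated as the state dict. Keys are cleaned via `clean_state_dict` before returning.
    """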
    if checkpoint_path and os.path.isfile(checkpoint_path):
        # Check if safetensors or not and load weights accordingly
        if str(checkpoint_path).endswith(".safetensors"):
            checkpoint = safetensors.torch.load_file(checkpoint_path, device='cpu')
        else:
            checkpoint = torch.load(checkpoint_path, map_location='cpu')

        state_dict_key = ''
        if isinstance(checkpoint, dict):
            if use_ema and checkpoint.get('state_dict_ema', None) is not None:
                state_dict_key = 'state_dict_ema'
            elif use_ema and checkpoint.get('model_ema', None) is not None:
                state_dict_key = 'model_ema'
            elif 'state_dict' in checkpoint:
                state_dict_key = 'state_dict'
            elif 'model' in checkpoint:
                state_dict_key = 'model'
        state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
        _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
        return state_dict
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()


def load_checkpoint(model, checkpoint_path, use_ema=True, strict=True, remap=False):
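    """ Load weights from a checkpoint file into a model.

    .npz/.npy checkpoints are handed to the model's own `load_pretrained` method when it
    has one; all other files go through `load_state_dict` above, optionally remapping
    tensors positionally via `remap_checkpoint`. Returns the incompatible keys reported
    by `model.load_state_dict` (or None for numpy checkpoints).

    Rough usage sketch (the checkpoint path is illustrative only):
        model = timm.create_model('resnet50')
        load_checkpoint(model, 'path/to/checkpoint.pth.tar')
    """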
    if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
        # numpy checkpoint, try to load via model specific load_pretrained fn
        if hasattr(model, 'load_pretrained'):
            model.load_pretrained(checkpoint_path)
        else:
            raise NotImplementedError('Model cannot load numpy checkpoint')
        return

    state_dict = load_state_dict(checkpoint_path, use_ema)
    if remap:
        state_dict = remap_checkpoint(model, state_dict)
    incompatible_keys = model.load_state_dict(state_dict, strict=strict)
    return incompatible_keys


def remap_checkpoint(model, state_dict, allow_reshape=True):
""" remap checkpoint by iterating over state dicts in order (ignoring original keys).
|
|
|
|
This assumes models (and originating state dict) were created with params registered in same order.
|
|
|
|
"""
|
|
|
|
out_dict = {}
|
|
|
|
for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
|
|
|
|
assert va.numel == vb.numel, f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
|
|
|
if va.shape != vb.shape:
|
|
|
|
if allow_reshape:
|
|
|
|
vb = vb.reshape(va.shape)
|
|
|
|
else:
|
|
|
|
assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
|
|
|
|
out_dict[ka] = vb
|
|
|
|
return out_dict


def resume_checkpoint(model, checkpoint_path, optimizer=None, loss_scaler=None, log_info=True):
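    """ Restore training state from a checkpoint for resuming a run.

    Loads model weights (cleaning any 'module.' prefix) and, when the corresponding
    objects are provided and the checkpoint contains matching entries, the optimizer and
    AMP loss scaler state as well. Returns the epoch to resume from if the checkpoint
    records one, otherwise None.
    """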
    resume_epoch = None
    if os.path.isfile(checkpoint_path):
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            if log_info:
                _logger.info('Restoring model state from checkpoint...')
            state_dict = clean_state_dict(checkpoint['state_dict'])
            model.load_state_dict(state_dict)

            if optimizer is not None and 'optimizer' in checkpoint:
                if log_info:
                    _logger.info('Restoring optimizer state from checkpoint...')
                optimizer.load_state_dict(checkpoint['optimizer'])

            if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
                if log_info:
                    _logger.info('Restoring AMP loss scaler state from checkpoint...')
                loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])

            if 'epoch' in checkpoint:
                resume_epoch = checkpoint['epoch']
                if 'version' in checkpoint and checkpoint['version'] > 1:
                    resume_epoch += 1  # start at the next epoch, old checkpoints incremented before save

            if log_info:
                _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
        else:
            model.load_state_dict(checkpoint)
            if log_info:
                _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
        return resume_epoch
    else:
        _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
        raise FileNotFoundError()