|
|
|
@ -7,7 +7,7 @@ from collections import OrderedDict
|
|
|
|
|
|
|
|
|
|
def load_checkpoint(model, checkpoint_path, use_ema=False):
|
|
|
|
|
if checkpoint_path and os.path.isfile(checkpoint_path):
|
|
|
|
|
checkpoint = torch.load(checkpoint_path)
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
|
|
state_dict_key = ''
|
|
|
|
|
if isinstance(checkpoint, dict):
|
|
|
|
|
state_dict_key = 'state_dict'
|
|
|
|
@ -32,7 +32,7 @@ def resume_checkpoint(model, checkpoint_path):
|
|
|
|
|
optimizer_state = None
|
|
|
|
|
resume_epoch = None
|
|
|
|
|
if os.path.isfile(checkpoint_path):
|
|
|
|
|
checkpoint = torch.load(checkpoint_path)
|
|
|
|
|
checkpoint = torch.load(checkpoint_path, map_location='cpu')
|
|
|
|
|
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
|
|
|
|
|
new_state_dict = OrderedDict()
|
|
|
|
|
for k, v in checkpoint['state_dict'].items():
|
|
|
|
|