Update helpers.py

Fixing out of memory error by loading the checkpoint onto the CPU.
pull/23/head
Minqin Chen 5 years ago committed by GitHub
parent 0c874195db
commit 4e7a854dd0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -7,7 +7,7 @@ from collections import OrderedDict
def load_checkpoint(model, checkpoint_path, use_ema=False): def load_checkpoint(model, checkpoint_path, use_ema=False):
if checkpoint_path and os.path.isfile(checkpoint_path): if checkpoint_path and os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path, map_location='cpu')
state_dict_key = '' state_dict_key = ''
if isinstance(checkpoint, dict): if isinstance(checkpoint, dict):
state_dict_key = 'state_dict' state_dict_key = 'state_dict'
@ -32,7 +32,7 @@ def resume_checkpoint(model, checkpoint_path):
optimizer_state = None optimizer_state = None
resume_epoch = None resume_epoch = None
if os.path.isfile(checkpoint_path): if os.path.isfile(checkpoint_path):
checkpoint = torch.load(checkpoint_path) checkpoint = torch.load(checkpoint_path, map_location='cpu')
if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
new_state_dict = OrderedDict() new_state_dict = OrderedDict()
for k, v in checkpoint['state_dict'].items(): for k, v in checkpoint['state_dict'].items():

Loading…
Cancel
Save