diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 1deff273..a37f5062 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -7,7 +7,7 @@ from collections import OrderedDict def load_checkpoint(model, checkpoint_path, use_ema=False): if checkpoint_path and os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location='cpu') state_dict_key = '' if isinstance(checkpoint, dict): state_dict_key = 'state_dict' @@ -32,7 +32,7 @@ def resume_checkpoint(model, checkpoint_path): optimizer_state = None resume_epoch = None if os.path.isfile(checkpoint_path): - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location='cpu') if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: new_state_dict = OrderedDict() for k, v in checkpoint['state_dict'].items():