diff --git a/timm/utils.py b/timm/utils.py index 59d2bcd0..1da69a96 100644 --- a/timm/utils.py +++ b/timm/utils.py @@ -259,7 +259,7 @@ class ModelEma: p.requires_grad_(False) def _load_checkpoint(self, checkpoint_path): - checkpoint = torch.load(checkpoint_path) + checkpoint = torch.load(checkpoint_path, map_location='cpu') assert isinstance(checkpoint, dict) if 'state_dict_ema' in checkpoint: new_state_dict = OrderedDict()