diff --git a/timm/models/helpers.py b/timm/models/helpers.py index fda84171..9734c779 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -37,8 +37,14 @@ def clean_state_dict(state_dict): # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training cleaned_state_dict = OrderedDict() for k, v in state_dict.items(): - name = k[7:] if k.startswith('module.') else k - cleaned_state_dict[name] = v + # strip `module.` and `model.` prefixes + if k.startswith('module'): + name = k[7:] + elif k.startswith('model'): + name = k[6:] + else: + name = k + new_state_dict[name] = v return cleaned_state_dict