diff --git a/timm/models/davit.py b/timm/models/davit.py index c48fb864..444f21f3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -546,7 +546,8 @@ def checkpoint_filter_fn(state_dict, model): out_dict = {} import re for k, v in state_dict.items(): - k = k.replace('norms.', 'head.norm.') + + k = k.replace('head.', 'head.fc.') out_dict[k] = v return out_dict