@ -546,7 +546,8 @@ def checkpoint_filter_fn(state_dict, model):
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('norms.', 'head.norm.')
k = k.replace('head.', 'head.fc.')
out_dict[k] = v
return out_dict