Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent a97963bff1
commit 1470cebb6c

@ -569,9 +569,10 @@ def checkpoint_filter_fn(state_dict, model):
import re
out_dict = {}
for k, v in state_dict.items():
k = k.replace('stages.0.patch_embed', 'patch_embed')
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
k = k.replace('stages.0.patch_embed', 'patch_embed')
k = k.replace('head.', 'head.fc.')
k = k.replace('cpe.0', 'cpe1')
k = k.replace('cpe.1', 'cpe2')

Loading…
Cancel
Save