diff --git a/timm/models/davit.py b/timm/models/davit.py index a51c8243..3ee6c4e6 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -569,6 +569,7 @@ def checkpoint_filter_fn(state_dict, model): import re out_dict = {} for k, v in state_dict.items(): + k = k.replace('stages.0.patch_embed', 'patch_embed') k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k) k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k) k = k.replace('head.', 'head.fc.')