|
|
@ -565,6 +565,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
|
|
|
|
k = re.sub(r'patch_embeds.([0-9]+)', r'stages.\1.patch_embed', k)
|
|
|
|
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
|
|
|
|
k = re.sub(r'main_blocks.([0-9]+)', r'stages.\1.blocks', k)
|
|
|
|
k = k.replace('stages.0.patch_embed', 'stem')
|
|
|
|
k = k.replace('stages.0.patch_embed', 'stem')
|
|
|
|
|
|
|
|
k = k.replace('norms.', 'norm.')
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
k = k.replace('head.', 'head.fc.')
|
|
|
|
k = k.replace('cpe.0', 'cpe1')
|
|
|
|
k = k.replace('cpe.0', 'cpe1')
|
|
|
|
k = k.replace('cpe.1', 'cpe2')
|
|
|
|
k = k.replace('cpe.1', 'cpe2')
|
|
|
|