|
|
@ -666,6 +666,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = k.replace('downsample.proj', 'downsample.conv')
|
|
|
|
k = k.replace('downsample.proj', 'downsample.conv')
|
|
|
|
|
|
|
|
k = k.replace('patch_embed.proj', 'patch_embed.conv')
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
k = re.sub(r'^head', 'head.fc', k)
|
|
|
|
k = re.sub(r'^head', 'head.fc', k)
|
|
|
|