|
|
@ -832,6 +832,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
for k, v in state_dict.items():
|
|
|
|
|
|
|
|
'''
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = k.replace('proj', 'conv')
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = re.sub(r'layer_scale_([0-9]+)', r'layer_scale\1.scale', k)
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
|
k = k.replace('network.1', 'downsample_layers.1')
|
|
|
@ -841,6 +842,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.4', 'network.2')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network.6', 'network.3')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
k = k.replace('network', 'stages')
|
|
|
|
|
|
|
|
'''
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'downsample_layers.([0-9]+)', r'stages.\1.downsample', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = re.sub(r'([0-9]+).([0-9]+)', r'\1.blocks.\2', k)
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|
k = k.replace('stages.0.downsample', 'patch_embed')
|
|
|
|