|
|
@ -626,7 +626,7 @@ def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
|
|
|
return posemb
|
|
|
|
return posemb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def checkpoint_filter_fn(state_dict, model):
|
|
|
|
def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
|
|
|
import re
|
|
|
|
import re
|
|
|
|
out_dict = {}
|
|
|
|
out_dict = {}
|
|
|
@ -647,7 +647,7 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
getattr(model, 'num_prefix_tokens', 1),
|
|
|
|
getattr(model, 'num_prefix_tokens', 1),
|
|
|
|
model.patch_embed.grid_size
|
|
|
|
model.patch_embed.grid_size
|
|
|
|
)
|
|
|
|
)
|
|
|
|
elif 'gamma_' in k:
|
|
|
|
elif adapt_layer_scale and 'gamma_' in k:
|
|
|
|
# remap layer-scale gamma into sub-module (deit3 models)
|
|
|
|
# remap layer-scale gamma into sub-module (deit3 models)
|
|
|
|
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
|
|
|
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
|
|
|
elif 'pre_logits' in k:
|
|
|
|
elif 'pre_logits' in k:
|
|
|
|