Take `no_embed_class` into account when calling `resize_pos_embed`

pull/1365/head
Ceshine Lee 2 years ago
parent d7b55a9429
commit 0b64117592

@ -644,7 +644,7 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False):
v = resize_pos_embed(
v,
model.pos_embed,
getattr(model, 'num_prefix_tokens', 1),
0 if getattr(model, 'no_embed_class', False) else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
)
elif adapt_layer_scale and 'gamma_' in k:

Loading…
Cancel
Save