diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 9066a9de..9602355b 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -644,7 +644,7 @@ def checkpoint_filter_fn(state_dict, model, adapt_layer_scale=False): v = resize_pos_embed( v, model.pos_embed, - getattr(model, 'num_prefix_tokens', 1), + 0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1), model.patch_embed.grid_size ) elif adapt_layer_scale and 'gamma_' in k: