From 3718c5a5bd4fd814b13a73c4cb3116d91337b7ff Mon Sep 17 00:00:00 2001 From: Richard Chen Date: Wed, 8 Sep 2021 11:53:05 -0400 Subject: [PATCH] fix loading pretrained model --- timm/models/crossvit.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/timm/models/crossvit.py b/timm/models/crossvit.py index 0873fdcc..f9296b74 100644 --- a/timm/models/crossvit.py +++ b/timm/models/crossvit.py @@ -337,11 +337,23 @@ def _create_crossvit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') + def pretrained_filter_fn(state_dict): + new_state_dict = {} + for key in state_dict.keys(): + if 'pos_embed' in key or 'cls_token' in key: + new_key = key.replace(".", "_") + else: + new_key = key + new_state_dict[new_key] = state_dict[key] + return new_state_dict + return build_model_with_cfg( CrossViT, variant, pretrained, default_cfg=default_cfgs[variant], + pretrained_filter_fn=pretrained_filter_fn, **kwargs) + @register_model def crossvit_tiny_224(pretrained=False, **kwargs):