diff --git a/timm/models/beit.py b/timm/models/beit.py index 60497d9a..1f6bf82b 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -384,6 +384,13 @@ class Beit(nn.Module): return x +def _beit_checkpoint_filter_fn(state_dict, model): + if 'module' in state_dict: + # beit v2 didn't strip module + state_dict = state_dict['module'] + return checkpoint_filter_fn(state_dict, model) + + def _create_beit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Beit models.') @@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs): model = build_model_with_cfg( Beit, variant, pretrained, # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=_beit_checkpoint_filter_fn, **kwargs) return model