From c8ab747bf4f785c43782a8ef3863100b0b6c0029 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Tue, 13 Sep 2022 17:56:49 -0700 Subject: [PATCH] BEiT-V2 checkpoints didn't remove 'module' from weights, adapt checkpoint filter --- timm/models/beit.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/timm/models/beit.py b/timm/models/beit.py index 60497d9a..1f6bf82b 100644 --- a/timm/models/beit.py +++ b/timm/models/beit.py @@ -384,6 +384,13 @@ class Beit(nn.Module): return x +def _beit_checkpoint_filter_fn(state_dict, model): + if 'module' in state_dict: + # beit v2 didn't strip module + state_dict = state_dict['module'] + return checkpoint_filter_fn(state_dict, model) + + def _create_beit(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Beit models.') @@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs): model = build_model_with_cfg( Beit, variant, pretrained, # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes - pretrained_filter_fn=checkpoint_filter_fn, + pretrained_filter_fn=_beit_checkpoint_filter_fn, **kwargs) return model