BEiT-V2 checkpoints didn't remove 'module' from weights, adapt checkpoint filter

pull/1471/head
Ross Wightman 2 years ago
parent 73049dc2aa
commit c8ab747bf4

@@ -384,6 +384,13 @@ class Beit(nn.Module):
         return x


+def _beit_checkpoint_filter_fn(state_dict, model):
+    if 'module' in state_dict:
+        # beit v2 didn't strip module
+        state_dict = state_dict['module']
+    return checkpoint_filter_fn(state_dict, model)
+
+
 def _create_beit(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Beit models.')
@@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
     model = build_model_with_cfg(
         Beit, variant, pretrained,
         # FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
-        pretrained_filter_fn=checkpoint_filter_fn,
+        pretrained_filter_fn=_beit_checkpoint_filter_fn,
         **kwargs)
     return model

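For context on what the new filter handles: BEiT-v2 checkpoints nest the weights one level down under a 'module' key instead of saving a flat state_dict, so timm's existing checkpoint_filter_fn never saw the actual tensor keys. A minimal sketch of the unwrap step follows; the tensor key and shape below are illustrative placeholders, not values read from a real checkpoint:

    import torch

    # Hypothetical checkpoint mimicking the BEiT-v2 save format:
    # the state_dict sits under a 'module' key, not at the top level.
    checkpoint = {'module': {'patch_embed.proj.weight': torch.zeros(768, 3, 16, 16)}}

    # Same unwrap step as _beit_checkpoint_filter_fn above; the real
    # filter then passes the flat dict on to checkpoint_filter_fn.
    state_dict = checkpoint['module'] if 'module' in checkpoint else checkpoint
    print(list(state_dict.keys()))  # -> ['patch_embed.proj.weight']

Checkpoints that are already flat pass through the 'module' check untouched, so the same filter works for BEiT v1 and v2 weights.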
