|
|
|
@ -384,6 +384,13 @@ class Beit(nn.Module):
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _beit_checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
if 'module' in state_dict:
|
|
|
|
|
# beit v2 didn't strip module
|
|
|
|
|
state_dict = state_dict['module']
|
|
|
|
|
return checkpoint_filter_fn(state_dict, model)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_beit(variant, pretrained=False, **kwargs):
|
|
|
|
|
if kwargs.get('features_only', None):
|
|
|
|
|
raise RuntimeError('features_only not implemented for Beit models.')
|
|
|
|
@ -391,7 +398,7 @@ def _create_beit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|
Beit, variant, pretrained,
|
|
|
|
|
# FIXME an updated filter fn needed to interpolate rel pos emb if fine tuning to diff model sizes
|
|
|
|
|
pretrained_filter_fn=checkpoint_filter_fn,
|
|
|
|
|
pretrained_filter_fn=_beit_checkpoint_filter_fn,
|
|
|
|
|
**kwargs)
|
|
|
|
|
return model
|
|
|
|
|
|
|
|
|
|