Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 01127c0e51
commit 036c8639dd

@ -657,7 +657,7 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs): def _create_davit(variant, pretrained=False, **kwargs):
out_indices = [i for i, _ in enumerate(kwargs.get(depths))] out_indices = [i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))]
feature_cfg = {'out_indices', out_indices} feature_cfg = {'out_indices', out_indices}
model = build_model_with_cfg(DaViT, variant, pretrained, model = build_model_with_cfg(DaViT, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=feature_cfg **kwargs) pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=feature_cfg **kwargs)

Loading…
Cancel
Save