Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 37191d8337
commit c9f42315de

@ -518,7 +518,7 @@ class DaViTFeatures(DaViT):
out_indices = kwargs.pop('out_indices', default_out_indices)
self.feature_info = FeatureInfo(self.feature_info, out_indices)
def forward_pyramid_features(self, x) -> List[Tensor]:
def forward(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x)
outs = []
for i, out in enumerate(x):
@ -546,11 +546,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs):
model_cls = HighResolutionNet
model_cls = DaViT
features_only = False
kwargs_filter = None
if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures
model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
model = build_model_with_cfg(

Loading…
Cancel
Save