Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 37191d8337
commit c9f42315de

@ -518,7 +518,7 @@ class DaViTFeatures(DaViT):
out_indices = kwargs.pop('out_indices', default_out_indices) out_indices = kwargs.pop('out_indices', default_out_indices)
self.feature_info = FeatureInfo(self.feature_info, out_indices) self.feature_info = FeatureInfo(self.feature_info, out_indices)
def forward_pyramid_features(self, x) -> List[Tensor]: def forward(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x) x, sizes = self.forward_network(x)
outs = [] outs = []
for i, out in enumerate(x): for i, out in enumerate(x):
@ -546,11 +546,11 @@ def checkpoint_filter_fn(state_dict, model):
def _create_davit(variant, pretrained=False, **kwargs): def _create_davit(variant, pretrained=False, **kwargs):
model_cls = HighResolutionNet model_cls = DaViT
features_only = False features_only = False
kwargs_filter = None kwargs_filter = None
if model_kwargs.pop('features_only', False): if model_kwargs.pop('features_only', False):
model_cls = HighResolutionNetFeatures model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool') kwargs_filter = ('num_classes', 'global_pool')
features_only = True features_only = True
model = build_model_with_cfg( model = build_model_with_cfg(

Loading…
Cancel
Save