|
|
|
@ -518,7 +518,7 @@ class DaViTFeatures(DaViT):
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, out_indices)
|
|
|
|
|
|
|
|
|
|
def forward_pyramid_features(self, x) -> List[Tensor]:
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
|
outs = []
|
|
|
|
|
for i, out in enumerate(x):
|
|
|
|
@ -546,11 +546,11 @@ def checkpoint_filter_fn(state_dict, model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model_cls = HighResolutionNet
|
|
|
|
|
model_cls = DaViT
|
|
|
|
|
features_only = False
|
|
|
|
|
kwargs_filter = None
|
|
|
|
|
if model_kwargs.pop('features_only', False):
|
|
|
|
|
model_cls = HighResolutionNetFeatures
|
|
|
|
|
model_cls = DaViTFeatures
|
|
|
|
|
kwargs_filter = ('num_classes', 'global_pool')
|
|
|
|
|
features_only = True
|
|
|
|
|
model = build_model_with_cfg(
|
|
|
|
|