|
|
|
@ -514,9 +514,7 @@ class DaViTFeatures(DaViT):
|
|
|
|
|
|
|
|
|
|
def __init__(*args):
|
|
|
|
|
super(DaViT, self).__init__(*args, **kwargs)
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, out_indices)
|
|
|
|
|
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
|
|
|
|
|
|
|
|
|
|
def forward(self, x) -> List[Tensor]:
|
|
|
|
|
x, sizes = self.forward_network(x)
|
|
|
|
@ -549,6 +547,8 @@ def _create_davit(variant, pretrained=False, **kwargs):
|
|
|
|
|
model_cls = DaViT
|
|
|
|
|
features_only = False
|
|
|
|
|
kwargs_filter = None
|
|
|
|
|
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
|
|
|
|
|
out_indices = kwargs.pop('out_indices', default_out_indices)
|
|
|
|
|
if kwargs.pop('features_only', False):
|
|
|
|
|
model_cls = DaViTFeatures
|
|
|
|
|
kwargs_filter = ('num_classes', 'global_pool')
|
|
|
|
|