Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8792e906ab
commit 163d951550

@ -514,9 +514,7 @@ class DaViTFeatures(DaViT):
def __init__(*args):
super(DaViT, self).__init__(*args, **kwargs)
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
self.feature_info = FeatureInfo(self.feature_info, out_indices)
self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
def forward(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x)
@ -549,6 +547,8 @@ def _create_davit(variant, pretrained=False, **kwargs):
model_cls = DaViT
features_only = False
kwargs_filter = None
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
if kwargs.pop('features_only', False):
model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool')

Loading…
Cancel
Save