Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 8792e906ab
commit 163d951550

@ -514,9 +514,7 @@ class DaViTFeatures(DaViT):
def __init__(*args): def __init__(*args):
super(DaViT, self).__init__(*args, **kwargs) super(DaViT, self).__init__(*args, **kwargs)
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))
out_indices = kwargs.pop('out_indices', default_out_indices)
self.feature_info = FeatureInfo(self.feature_info, out_indices)
def forward(self, x) -> List[Tensor]: def forward(self, x) -> List[Tensor]:
x, sizes = self.forward_network(x) x, sizes = self.forward_network(x)
@ -549,6 +547,8 @@ def _create_davit(variant, pretrained=False, **kwargs):
model_cls = DaViT model_cls = DaViT
features_only = False features_only = False
kwargs_filter = None kwargs_filter = None
default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1))))
out_indices = kwargs.pop('out_indices', default_out_indices)
if kwargs.pop('features_only', False): if kwargs.pop('features_only', False):
model_cls = DaViTFeatures model_cls = DaViTFeatures
kwargs_filter = ('num_classes', 'global_pool') kwargs_filter = ('num_classes', 'global_pool')

Loading…
Cancel
Save