From 163d95155078a2faea1e98d20c0ec1ae4aed263d Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 07:21:28 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 1bc8c7b0..539086c4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -514,9 +514,7 @@ class DaViTFeatures(DaViT): def __init__(*args): super(DaViT, self).__init__(*args, **kwargs) - default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) - out_indices = kwargs.pop('out_indices', default_out_indices) - self.feature_info = FeatureInfo(self.feature_info, out_indices) + self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4))) def forward(self, x) -> List[Tensor]: x, sizes = self.forward_network(x) @@ -549,6 +547,8 @@ def _create_davit(variant, pretrained=False, **kwargs): model_cls = DaViT features_only = False kwargs_filter = None + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) if kwargs.pop('features_only', False): model_cls = DaViTFeatures kwargs_filter = ('num_classes', 'global_pool')