From c9f42315debcc398fc9b9f806f71fff1ec49d4f8 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 07:14:42 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 3f0d7107..5fcc6052 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -518,7 +518,7 @@ class DaViTFeatures(DaViT): out_indices = kwargs.pop('out_indices', default_out_indices) self.feature_info = FeatureInfo(self.feature_info, out_indices) - def forward_pyramid_features(self, x) -> List[Tensor]: + def forward(self, x) -> List[Tensor]: x, sizes = self.forward_network(x) outs = [] for i, out in enumerate(x): @@ -546,11 +546,11 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): - model_cls = HighResolutionNet + model_cls = DaViT features_only = False kwargs_filter = None if model_kwargs.pop('features_only', False): - model_cls = HighResolutionNetFeatures + model_cls = DaViTFeatures kwargs_filter = ('num_classes', 'global_pool') features_only = True model = build_model_with_cfg(