From 729fd0d53b24a6b76bea95ae1e1302680dae7858 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 21:04:33 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index b1f82e23..3cca1a3d 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -788,7 +788,7 @@ class DaViT(nn.Module): if global_pool is None: global_pool = self.head.global_pool.pool_type self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - + ''' def forward_network(self, x : Tensor): size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] @@ -808,16 +808,16 @@ class DaViT(nn.Module): def forward_pyramid_features(self, x) -> List[Tensor]: x = self.forward_network(x) - ''' + outs = [] for i, out in enumerate(x): H, W = sizes[i] outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) - ''' + return x - + ''' def forward_features(self, x): - x = self.forward_network(x) + x = self.stages(x) # take final feature and norm x = self.norms(x[-1].permute(0, 2, 3, 1)).permute(0, 3, 1, 2) #H, W = sizes[-1] @@ -834,7 +834,7 @@ class DaViT(nn.Module): def forward(self, x): return self.forward_classifier(x) - +''' class DaViTFeatures(DaViT): def __init__(self, *args, **kwargs): @@ -843,7 +843,7 @@ class DaViTFeatures(DaViT): def forward(self, x) -> List[Tensor]: return self.forward_pyramid_features(x) - +''' def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """ if 'head.norm.weight' in state_dict: @@ -866,16 +866,11 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): - model_cls = DaViT - features_only = False - kwargs_filter = None + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) - if kwargs.pop('features_only', False): - model_cls = DaViTFeatures - kwargs_filter = ('num_classes', 'global_pool') - features_only = True + model = build_model_with_cfg( model_cls, variant, @@ -883,9 +878,7 @@ def _create_davit(variant, pretrained=False, **kwargs): pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) - if features_only: - model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) - model.default_cfg = model.pretrained_cfg # backwards compat + return model