From 1b7671439741f35d00251140c5a021af781e5ca1 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 04:16:11 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 076e84b5..f019c5d8 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -514,7 +514,7 @@ class DaViT(nn.Module): self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] - self.stages = nn.ModuleList(stages) + self.stages = SequentialWithSize(*stages) self.norms = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) @@ -544,7 +544,7 @@ class DaViT(nn.Module): global_pool = self.head.global_pool.pool_type self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - + ''' def forward_network(self, x : Tensor): size: Tuple[int, int] = (x.size(2), x.size(3)) features = [x] @@ -570,12 +570,19 @@ class DaViT(nn.Module): outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) return outs - + ''' + + + + def forward_features(self, x): - x, sizes = self.forward_network(x) + #x, sizes = self.forward_network(x) + size: Tuple[int, int] = (x.size(2), x.size(3)) + x, size = stages(x, size) + # take final feature and norm - x = self.norms(x[-1]) - H, W = sizes[-1] + x = self.norms(x) + H, W = sizes x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() return x @@ -590,7 +597,8 @@ class DaViT(nn.Module): def forward(self, x): return self.forward_classifier(x) - + +''' class DaViTFeatures(DaViT): def __init__(self, *args, **kwargs): @@ -600,6 +608,8 @@ class DaViTFeatures(DaViT): def forward(self, x) -> List[Tensor]: return self.forward_pyramid_features(x) +''' + def checkpoint_filter_fn(state_dict, model): """ Remap MSFT checkpoints -> timm """ @@ -620,7 +630,7 @@ def checkpoint_filter_fn(state_dict, model): return out_dict - +''' def _create_davit(variant, pretrained=False, **kwargs): model_cls = DaViT features_only = False @@ -643,6 +653,19 @@ def _create_davit(variant, pretrained=False, **kwargs): model.default_cfg = model.pretrained_cfg # backwards compat return model +''' +def _create_davit(variant, pretrained=False, **kwargs): + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + model = build_model_with_cfg( + DaViT, + variant, + pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), + **kwargs) + + return model def _cfg(url='', **kwargs): # not sure how this should be set up