From 5d8ea5a21d6687b88e668fdf9faf235e9acc8c66 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Thu, 8 Dec 2022 07:07:32 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 53 ++++++++++++++++++++++++++------------------ 1 file changed, 32 insertions(+), 21 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index dc86c219..11f6ff11 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -23,8 +23,9 @@ import torch.nn as nn import torch.nn.functional as F from torch import Tensor from .helpers import build_model_with_cfg -from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp +from .features import FeatureInfo from collections import OrderedDict import torch.utils.checkpoint as checkpoint from .pretrained import generate_default_cfgs @@ -381,8 +382,7 @@ class DaViT(nn.Module): drop_rate=0., attn_drop_rate=0., num_classes=1000, - global_pool='avg', - **kwargs + global_pool='avg' ): super().__init__() @@ -439,12 +439,7 @@ class DaViT(nn.Module): ]) self.stages.add_module(f'stage_{stage_id}', stage) - - self.feature_info += [dict( - num_chs=self.embed_dims[stage_id], - reduction = 2, - module=f'stages.stage_{stage_id}.{depths[stage_id] - 1}.{len(attention_types) - 1}.mlp')] - + self.feature_info += [dict(num_chs=self.embed_dims[item], reduction=2, module=f'stages.stage_{stage_id}')] self.norms = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) @@ -498,15 +493,7 @@ class DaViT(nn.Module): # non-normalized pyramid features + corresponding sizes return features, sizes - def forward_pyramid_features(self, x): - x, sizes = self.forward_network(x) - outs = [] - for i, out in enumerate(x): - H, W = sizes[i] - outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) - - return outs - + def forward_features(self, x): x, sizes = self.forward_network(x) # take final feature and norm @@ -523,6 +510,22 @@ class DaViT(nn.Module): x = self.forward_head(x) return x +class DaViTFeatures(DaViT): + + def __init__(*args): + super(DaViT, self).__init__(*args, **kwargs) + default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) + out_indices = kwargs.pop('out_indices', default_out_indices) + self.feature_info = FeatureInfo(self.feature_info, out_indices) + + def forward_pyramid_features(self, x) -> List[Tensor]: + x, sizes = self.forward_network(x) + outs = [] + for i, out in enumerate(x): + H, W = sizes[i] + outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) + + return outs def checkpoint_filter_fn(state_dict, model): @@ -543,15 +546,23 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): - default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) - out_indices = kwargs.pop('out_indices', default_out_indices) + model_cls = HighResolutionNet + features_only = False + kwargs_filter = None + if model_kwargs.pop('features_only', False): + model_cls = HighResolutionNetFeatures + kwargs_filter = ('num_classes', 'global_pool') + features_only = True model = build_model_with_cfg( - DaViT, + model_cls, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) + if features_only: + model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) + model.default_cfg = model.pretrained_cfg # backwards compat return model