From b355135a0adae7e4cd2b01d3a6920f1f4129cedd Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 01:06:08 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 97 ++++++++++++++++++++++++++++++++------------ 1 file changed, 70 insertions(+), 27 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index caa91707..e14ab8e3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint from .features import FeatureInfo from .fx_features import register_notrace_function, register_notrace_module -from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features +from .helpers import build_model_with_cfg, checkpoint_seq pretrained_cfg_for_features from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp from .pretrained import generate_default_cfgs from .registry import register_model @@ -35,18 +35,10 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] class SequentialWithSize(nn.Sequential): - def __init__(self, *args, **kwargs): - super(SequentialWithSize, self).__init__(*args, **kwargs) - - def forward(self, x: Tensor, size: Tuple[int, int]): - for module in self.__iter__(): + def forward(self, x : Tensor, size: Tuple[int, int]): + for module in self._modules.values(): x, size = module(x, size) - ''' - output = module(x, size) - x : Tensor = output[0] - size : Tuple[int, int] = output[1] - ''' - return x, size + return x, size class ConvPosEnc(nn.Module): @@ -419,17 +411,19 @@ class DaViTStage(nn.Module): window_size=window_size, )) - stage_blocks.append(SequentialWithSize(*dual_attention_block)) + stage_blocks.append(nn.ModuleList(*dual_attention_block)) - self.blocks = SequentialWithSize(*stage_blocks) + self.blocks = nn.ModuleList(*stage_blocks) def forward(self, x : Tensor, size: Tuple[int, int]): x, size = self.patch_embed(x, size) - if self.grad_checkpointing and not torch.jit.is_scripting(): - x, size = checkpoint_seq(self.blocks, x, size) - else: - x, size = self.blocks(x, size) - + for block in self.blocks + for layer in block: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x, size = checkpoint.checkpoint(layer, x, size) + else: + x, size = layer(x, size) + return x, size class DaViT(nn.Module): @@ -514,12 +508,12 @@ class DaViT(nn.Module): self.feature_info += [dict(num_chs=self.embed_dims[stage_id], reduction=2, module=f'stages.{stage_id}')] - self.stages = SequentialWithSize(*stages) + self.stages = nn.ModuleList(*stages) self.norms = norm_layer(self.num_features) self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate) self.apply(self._init_weights) - + def _init_weights(self, m): if isinstance(m, nn.Linear): @@ -545,21 +539,60 @@ class DaViT(nn.Module): self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate) - def forward_features(self, x): + def forward_network(self, x): size: Tuple[int, int] = (x.size(2), x.size(3)) - x, size = self.stages(x, size) - x = self.norms(x) - H, W = size + features = [x] + sizes = [size] + + for stage in self.stages: + features[-1], sizes[-1] = stage(features[-1], sizes[-1]) + + # don't append outputs of last stage, since they are already there + if(len(features) < self.num_stages): + features.append(features[-1]) + sizes.append(sizes[-1]) + + + # non-normalized pyramid features + corresponding sizes + return features, sizes + + def forward_pyramid_features(self, x) -> List[Tensor]: + x, sizes = self.forward_network(x) + outs = [] + for i, out in enumerate(x): + H, W = sizes[i] + outs.append(out.view(-1, H, W, self.embed_dims[i]).permute(0, 3, 1, 2).contiguous()) + + return outs + + def forward_features(self, x): + x, sizes = self.forward_network(x) + # take final feature and norm + x = self.norms(x[-1]) + H, W = sizes[-1] x = x.view(-1, H, W, self.embed_dims[-1]).permute(0, 3, 1, 2).contiguous() return x def forward_head(self, x, pre_logits: bool = False): return self.head(x, pre_logits=pre_logits) - def forward(self, x): + def forward_classifier(self, x): x = self.forward_features(x) x = self.forward_head(x) return x + + def forward(self, x): + return self.forward_classifier(x) + + +class DaViTFeatures(DaViT): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_indices', (0, 1, 2, 3))) + + def forward(self, x) -> List[Tensor]: + return self.forward_pyramid_features(x) def checkpoint_filter_fn(state_dict, model): @@ -580,15 +613,25 @@ def checkpoint_filter_fn(state_dict, model): def _create_davit(variant, pretrained=False, **kwargs): + model_cls = DaViT + features_only = False + kwargs_filter = None default_out_indices = tuple(i for i, _ in enumerate(kwargs.get('depths', (1, 1, 3, 1)))) out_indices = kwargs.pop('out_indices', default_out_indices) + if kwargs.pop('features_only', False): + model_cls = DaViTFeatures + kwargs_filter = ('num_classes', 'global_pool') + features_only = True model = build_model_with_cfg( - DaViT, + model_cls, variant, pretrained, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(flatten_sequential=True, out_indices=out_indices), **kwargs) + if features_only: + model.pretrained_cfg = pretrained_cfg_for_features(model.default_cfg) + model.default_cfg = model.pretrained_cfg # backwards compat return model