From cb55f17dde5556bf2cf2bf55bee463a1551c9772 Mon Sep 17 00:00:00 2001
From: Fredo Guan
Date: Sat, 10 Dec 2022 04:28:13 -0800
Subject: [PATCH] Update davit.py

---
 timm/models/davit.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/timm/models/davit.py b/timm/models/davit.py
index 9c8b608e..220623e4 100644
--- a/timm/models/davit.py
+++ b/timm/models/davit.py
@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
 
 from .features import FeatureInfo
 from .fx_features import register_notrace_function, register_notrace_module
-from .helpers import build_model_with_cfg, checkpoint_seq, pretrained_cfg_for_features
+from .helpers import build_model_with_cfg, pretrained_cfg_for_features
 from .layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, ClassifierHead, Mlp
 from .pretrained import generate_default_cfgs
 from .registry import register_model
@@ -421,7 +421,7 @@ class DaViTStage(nn.Module):
 
             stage_blocks.append(nn.ModuleList(dual_attention_block))
 
-        self.blocks = nn.ModuleList(stage_blocks)
+        self.blocks = SequentialWithSize(*stage_blocks)
 
     def forward(self, x : Tensor, size: Tuple[int, int]):
         x, size = self.patch_embed(x, size)
@@ -516,7 +516,6 @@ class DaViT(nn.Module):
 
         self.stages = SequentialWithSize(*stages)
 
-        self.norms = norm_layer(self.num_features)
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
 
         self.apply(self._init_weights)
@@ -579,7 +578,7 @@ class DaViT(nn.Module):
     def forward_features(self, x):
         #x, sizes = self.forward_network(x)
         size: Tuple[int, int] = (x.size(2), x.size(3))
-        x, size = stages(x, size)
+        x, size = self.stages(x, size)
 
         # take final feature and norm
         x = self.norms(x)
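
Note: SequentialWithSize is not shown in this diff; it is defined elsewhere in
davit.py. For review context, here is a minimal sketch of what such a container
plausibly looks like, assuming each child module (e.g. DaViTStage) accepts and
returns an (x, size) pair -- the in-tree definition may differ:

    import torch.nn as nn
    from torch import Tensor
    from typing import Tuple

    class SequentialWithSize(nn.Sequential):
        # Sequential variant that threads an (H, W) size tuple through
        # every child module alongside the feature tensor, so callers can
        # write: x, size = self.stages(x, size)
        def forward(self, x: Tensor, size: Tuple[int, int]):
            for module in self._modules.values():
                x, size = module(x, size)
            return x, size

Because such a container passes (x, size) pairs rather than a single tensor,
helpers.checkpoint_seq (which forwards only x through each child) cannot be
applied to these stages, which is consistent with this patch dropping that
import.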