diff --git a/timm/models/davit.py b/timm/models/davit.py index 31e21fef..7a6d8142 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -419,7 +419,7 @@ class DaViTStage(nn.Module): cpe_act=cpe_act )) - stage_blocks.append(nn.ModuleList(dual_attention_block)) + stage_blocks.append(SequentialWithSize(*dual_attention_block)) self.blocks = SequentialWithSize(*stage_blocks)