diff --git a/timm/models/davit.py b/timm/models/davit.py index a307ba6a..452f30c7 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -36,12 +36,9 @@ from .registry import register_model __all__ = ['DaViT'] class MySequential(nn.Sequential): - def forward(self, inputs : List[Tensor]): + def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): - if len(inputs) > 1: inputs = module(*inputs) - else: - inputs = module(inputs) return inputs class ConvPosEnc(nn.Module):