From c902731699952762d3145a0e6c72359bda7175f0 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 20:50:09 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index a307ba6a..452f30c7 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -36,12 +36,9 @@ from .registry import register_model __all__ = ['DaViT'] class MySequential(nn.Sequential): - def forward(self, inputs : List[Tensor]): + def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): - if len(inputs) > 1: inputs = module(*inputs) - else: - inputs = module(inputs) return inputs class ConvPosEnc(nn.Module):