diff --git a/timm/models/davit.py b/timm/models/davit.py index e9957f19..fe3cbad4 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -36,7 +36,7 @@ from .registry import register_model __all__ = ['DaViT'] class MySequential(nn.Sequential): - def forward(self, inputs : Tuple[Tensor, Tensor]): + def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]): for module in self._modules.values(): #if type(inputs) == tuple: inputs = module(*inputs)