diff --git a/timm/models/davit.py b/timm/models/davit.py index da4d7f52..72aefb00 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -38,13 +38,13 @@ __all__ = ['DaViT'] class MySequential(nn.Sequential): - @Overload + @overload def forward(self, inputs : Tensor): for module in self._modules.values(): inputs = module(inputs) return inputs - @Overload + @overload def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): inputs = module(*inputs)