diff --git a/timm/models/davit.py b/timm/models/davit.py index 17898160..da4d7f52 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -17,7 +17,8 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools -from typing import Tuple +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union + import torch import torch.nn as nn @@ -37,12 +38,16 @@ __all__ = ['DaViT'] class MySequential(nn.Sequential): - def forward(self, *inputs): + @Overload + def forward(self, inputs : Tensor): + for module in self._modules.values(): + inputs = module(inputs) + return inputs + + @Overload + def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): - if type(inputs) == tuple: inputs = module(*inputs) - else: - inputs = module(inputs) return inputs