diff --git a/timm/models/davit.py b/timm/models/davit.py index 18550b4c..72c11679 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -34,18 +34,18 @@ from .registry import register_model __all__ = ['DaViT'] -''' + class MySequential(nn.Sequential): - def forward(self, *inputs): + def forward(self, inputs : Tuple[Tensor, Tensor]): for module in self._modules.values(): - if type(inputs) == tuple: - inputs = module(*inputs) - else: - inputs = module(inputs) + #if type(inputs) == tuple: + inputs = module(*inputs) + #else: + # inputs = module(inputs) return inputs -''' +''' class MySequential(nn.Sequential): @overload def forward(self, inputs : Tensor): @@ -59,7 +59,7 @@ class MySequential(nn.Sequential): inputs = module(*inputs) return inputs - +''' class ConvPosEnc(nn.Module): def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): super(ConvPosEnc, self).__init__()