diff --git a/timm/models/davit.py b/timm/models/davit.py index 61388fe2..ae7b3836 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -17,7 +17,7 @@ DaViT model defs and weights adapted from https://github.com/dingmyu/davit, orig # This source code is licensed under the MIT license import itertools -from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, overload, Tuple, TypeVar, Union, List +from typing import Any, Dict, Iterable, Iterator, Mapping, Optional, Tuple, TypeVar, Union, List import torch @@ -43,21 +43,7 @@ class MySequential(nn.Sequential): else: inputs = module(inputs) return inputs -''' -class MySequential(nn.Sequential): - @overload - def forward(self, inputs : Tensor): - for module in self._modules.values(): - inputs = module(inputs) - return inputs - - @overload - def forward(self, inputs : Tuple[Tensor, Tensor]): - for module in self._modules.values(): - inputs = module(*inputs) - return inputs -''' class ConvPosEnc(nn.Module): def __init__(self, dim, k=3, act=False, normtype=False): super(ConvPosEnc, self).__init__()