From 6abb3ba634715f27137c22c787f6d94234a92d56 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Tue, 6 Dec 2022 21:02:33 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index d36a2742..18550b4c 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -34,16 +34,18 @@ from .registry import register_model __all__ = ['DaViT'] - +''' class MySequential(nn.Sequential): - def forward(self, inputs : Tuple): + def forward(self, *inputs): for module in self._modules.values(): if type(inputs) == tuple: inputs = module(*inputs) else: inputs = module(inputs) return inputs + ''' + class MySequential(nn.Sequential): @overload def forward(self, inputs : Tensor): @@ -57,9 +59,9 @@ class MySequential(nn.Sequential): inputs = module(*inputs) return inputs -''' + class ConvPosEnc(nn.Module): - def __init__(self, dim, k=3, act=False, normtype=False): + def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'): super(ConvPosEnc, self).__init__() self.proj = nn.Conv2d(dim, dim,