Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent e8fd52e39f
commit 6abb3ba634

@ -34,16 +34,18 @@ from .registry import register_model
__all__ = ['DaViT'] __all__ = ['DaViT']
'''
class MySequential(nn.Sequential): class MySequential(nn.Sequential):
def forward(self, inputs : Tuple): def forward(self, *inputs):
for module in self._modules.values(): for module in self._modules.values():
if type(inputs) == tuple: if type(inputs) == tuple:
inputs = module(*inputs) inputs = module(*inputs)
else: else:
inputs = module(inputs) inputs = module(inputs)
return inputs return inputs
''' '''
class MySequential(nn.Sequential): class MySequential(nn.Sequential):
@overload @overload
def forward(self, inputs : Tensor): def forward(self, inputs : Tensor):
@ -57,9 +59,9 @@ class MySequential(nn.Sequential):
inputs = module(*inputs) inputs = module(*inputs)
return inputs return inputs
'''
class ConvPosEnc(nn.Module): class ConvPosEnc(nn.Module):
def __init__(self, dim, k=3, act=False, normtype=False): def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
super(ConvPosEnc, self).__init__() super(ConvPosEnc, self).__init__()
self.proj = nn.Conv2d(dim, self.proj = nn.Conv2d(dim,
dim, dim,

Loading…
Cancel
Save