|
|
@ -34,18 +34,18 @@ from .registry import register_model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
'''
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
def forward(self, *inputs):
|
|
|
|
def forward(self, inputs : Tuple[Tensor, Tensor]):
|
|
|
|
for module in self._modules.values():
|
|
|
|
for module in self._modules.values():
|
|
|
|
if type(inputs) == tuple:
|
|
|
|
#if type(inputs) == tuple:
|
|
|
|
inputs = module(*inputs)
|
|
|
|
inputs = module(*inputs)
|
|
|
|
else:
|
|
|
|
#else:
|
|
|
|
inputs = module(inputs)
|
|
|
|
# inputs = module(inputs)
|
|
|
|
return inputs
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
@overload
|
|
|
|
@overload
|
|
|
|
def forward(self, inputs : Tensor):
|
|
|
|
def forward(self, inputs : Tensor):
|
|
|
@ -59,7 +59,7 @@ class MySequential(nn.Sequential):
|
|
|
|
inputs = module(*inputs)
|
|
|
|
inputs = module(*inputs)
|
|
|
|
return inputs
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
'''
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
def __init__(self, dim : int, k : int=3, act : bool=False, normtype : str='none'):
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|
super(ConvPosEnc, self).__init__()
|
|
|
|