|
|
@ -36,12 +36,9 @@ from .registry import register_model
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
__all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
class MySequential(nn.Sequential):
|
|
|
|
def forward(self, inputs : List[Tensor]):
|
|
|
|
def forward(self, inputs : Tuple[Tensor, Tensor]):
|
|
|
|
for module in self._modules.values():
|
|
|
|
for module in self._modules.values():
|
|
|
|
if len(inputs) > 1:
|
|
|
|
|
|
|
|
inputs = module(*inputs)
|
|
|
|
inputs = module(*inputs)
|
|
|
|
else:
|
|
|
|
|
|
|
|
inputs = module(inputs)
|
|
|
|
|
|
|
|
return inputs
|
|
|
|
return inputs
|
|
|
|
|
|
|
|
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|
class ConvPosEnc(nn.Module):
|
|
|
|