|
|
@ -36,8 +36,8 @@ __all__ = ['DaViT']
|
|
|
|
|
|
|
|
|
|
|
|
# modified nn.Sequential that includes a size tuple in the forward function
|
|
|
|
# modified nn.Sequential that includes a size tuple in the forward function
|
|
|
|
|
|
|
|
|
|
|
|
@register_notrace_module
|
|
|
|
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
class SequentialWithSize(nn.Sequential):
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
def forward(self, x : Tensor, size: Tuple[int, int]):
|
|
|
|
for module in self._modules.values():
|
|
|
|
for module in self._modules.values():
|
|
|
|
x, size = module(x, size)
|
|
|
|
x, size = module(x, size)
|
|
|
|