@@ -35,6 +35,7 @@ from .registry import register_model
 
 __all__ = ['DaViT']
 
+'''
 class MySequential(nn.Sequential):
     def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
         for module in self:
@@ -45,7 +46,7 @@ class MySequential(nn.Sequential):
             # inputs = module(inputs)
         return inputs
 
 '''
-'''
+
 class MySequential(nn.Sequential):
     @overload
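
For orientation (a sketch of the surrounding intent, not code from this patch): `nn.Sequential` is subclassed here because its stock `forward` accepts exactly one input, while DaViT blocks thread a `(features, size)` pair through every stage. A minimal tuple-threading variant is below, under the assumption that each child module takes and returns the pair; the replacement class kept by this hunk adds `@overload` stubs on top of this, presumably to pin down input types for `torch.jit`. `TupleSequential` is a hypothetical name.

```python
from typing import Tuple

from torch import Tensor, nn


class TupleSequential(nn.Sequential):
    # Hypothetical stand-in: chains modules that each consume and return a
    # (tensor, spatial-size) pair instead of a single tensor.
    def forward(self, x: Tensor, size: Tuple[int, int]) -> Tuple[Tensor, Tuple[int, int]]:
        for module in self:
            x, size = module(x, size)
        return x, size
```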
@@ -517,9 +518,9 @@ class DaViT(nn.Module):
         for _, layer in enumerate(block):
             print(layer)
             if self.grad_checkpointing and not torch.jit.is_scripting():
-                features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
+                features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
             else:
-                features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
+                features[-1], sizes[-1] = layer(features[-1], sizes[-1])
 
 
         features.append(features[-1])
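
Why the call sites changed (my reading of the hunk, not an explanation from the patch itself): the default reentrant `torch.utils.checkpoint.checkpoint(fn, *args)` only detaches and tracks tensors passed as direct positional arguments. A tensor hidden inside a plain Python tuple is invisible to it, so the checkpointed stage produces an output with `requires_grad=False` and the backward pass fails. A minimal sketch, where `Block` is a hypothetical stand-in for a DaViT stage:

```python
import torch
from torch import nn
from torch.utils import checkpoint


class Block(nn.Module):
    # Hypothetical stand-in for a DaViT stage: consumes a tensor and a size
    # tuple, returns the transformed tensor and the (possibly new) size.
    def forward(self, x, size):
        return x * 2.0, size


block = Block()
x = torch.randn(2, 8, requires_grad=True)

# New calling convention: the tensor is a direct positional argument, so the
# (default) reentrant checkpoint can detach and track it, and gradients
# reach x after backward.
out, size = checkpoint.checkpoint(block, x, (4, 2))
out.sum().backward()
assert x.grad is not None

# The old convention passed one plain tuple, (x, (4, 2)), to a forward that
# unpacked it itself; reentrant checkpoint then sees no tensor argument at
# all, the output is created with requires_grad=False, and backward() fails.
```

The non-tensor `(4, 2)` size argument is fine either way; only the tensor has to be a direct argument for gradient tracking.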