Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 625bb86a64
commit f8c5150fcd

@ -35,6 +35,7 @@ from .registry import register_model
__all__ = ['DaViT'] __all__ = ['DaViT']
'''
class MySequential(nn.Sequential): class MySequential(nn.Sequential):
def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]): def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
for module in self: for module in self:
@ -45,7 +46,7 @@ class MySequential(nn.Sequential):
# inputs = module(inputs) # inputs = module(inputs)
return inputs return inputs
'''
''' '''
class MySequential(nn.Sequential): class MySequential(nn.Sequential):
@overload @overload
@ -517,9 +518,9 @@ class DaViT(nn.Module):
for _, layer in enumerate(block): for _, layer in enumerate(block):
print(layer) print(layer)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1])) features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
else: else:
features[-1], sizes[-1] = layer((features[-1], sizes[-1])) features[-1], sizes[-1] = layer(features[-1], sizes[-1])
features.append(features[-1]) features.append(features[-1])

Loading…
Cancel
Save