Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 625bb86a64
commit f8c5150fcd

@ -35,6 +35,7 @@ from .registry import register_model
__all__ = ['DaViT']
'''
class MySequential(nn.Sequential):
def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]):
for module in self:
@ -45,7 +46,7 @@ class MySequential(nn.Sequential):
# inputs = module(inputs)
return inputs
'''
'''
class MySequential(nn.Sequential):
@overload
@ -517,9 +518,9 @@ class DaViT(nn.Module):
for _, layer in enumerate(block):
print(layer)
if self.grad_checkpointing and not torch.jit.is_scripting():
features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1]))
features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1])
else:
features[-1], sizes[-1] = layer((features[-1], sizes[-1]))
features[-1], sizes[-1] = layer(features[-1], sizes[-1])
features.append(features[-1])

Loading…
Cancel
Save