Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent 1bff091c78
commit 3cff8d0ff4

@ -411,7 +411,7 @@ class DaViTStage(nn.Module):
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x : Tensor):
x = self.patch_embed(x)
#x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:

Loading…
Cancel
Save