Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent b1337d7f4c
commit 38168b0741

@ -408,6 +408,10 @@ class DaViTStage(nn.Module):
self.blocks = nn.Sequential(*stage_blocks) self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x : Tensor): def forward(self, x : Tensor):
x = self.patch_embed(x) x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting(): if self.grad_checkpointing and not torch.jit.is_scripting():
@ -522,6 +526,8 @@ class DaViT(nn.Module):
@torch.jit.ignore @torch.jit.ignore
def set_grad_checkpointing(self, enable=True): def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore @torch.jit.ignore
def get_classifier(self): def get_classifier(self):
@ -535,7 +541,10 @@ class DaViT(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
x = self.stages(x) if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x return x

Loading…
Cancel
Save