Update davit.py

pull/1630/head
Fredo Guan 3 years ago
parent b1337d7f4c
commit 38168b0741

@ -407,7 +407,11 @@ class DaViTStage(nn.Module):
stage_blocks.append(nn.Sequential(*dual_attention_block))
self.blocks = nn.Sequential(*stage_blocks)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, x : Tensor):
x = self.patch_embed(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
@ -522,6 +526,8 @@ class DaViT(nn.Module):
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
for stage in self.stages:
stage.set_grad_checkpointing(enable=enable)
@torch.jit.ignore
def get_classifier(self):
@ -535,7 +541,10 @@ class DaViT(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.stages, x)
else:
x = self.stages(x)
x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
return x

Loading…
Cancel
Save