@@ -407,7 +407,11 @@ class DaViTStage(nn.Module):
             stage_blocks.append(nn.Sequential(*dual_attention_block))
 
         self.blocks = nn.Sequential(*stage_blocks)
 
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable=True):
+        self.grad_checkpointing = enable
+
     def forward(self, x : Tensor):
         x = self.patch_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
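
The hunk above gives DaViTStage a TorchScript-ignored setter that flips self.grad_checkpointing, and the stage's forward consults that flag before the hunk cuts off. Below is a minimal, self-contained sketch of the same toggle-plus-guard pattern on a toy module (ToyStage and its blocks are illustrative assumptions, not the timm code), using torch.utils.checkpoint directly:

# Sketch of the pattern this hunk introduces; ToyStage and its block contents
# are made up for illustration.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class ToyStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 16))
        self.grad_checkpointing = False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        # Ignored by TorchScript, so scripting the module still works.
        self.grad_checkpointing = enable

    def forward(self, x):
        if self.grad_checkpointing and not torch.jit.is_scripting():
            # Recompute each block's activations during backward instead of storing them.
            for block in self.blocks:
                x = checkpoint(block, x, use_reentrant=False)
        else:
            x = self.blocks(x)
        return x

stage = ToyStage()
stage.set_grad_checkpointing(True)
stage(torch.randn(2, 16, requires_grad=True)).sum().backward()
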
@@ -522,6 +526,8 @@ class DaViT(nn.Module):
     @torch.jit.ignore
     def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable
+        for stage in self.stages:
+            stage.set_grad_checkpointing(enable=enable)
 
     @torch.jit.ignore
     def get_classifier(self):
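
With the two added lines, the model-level setter no longer just records the flag on DaViT itself; it pushes it into every stage, so one call toggles checkpointing throughout the model. A short usage sketch follows; the model name 'davit_tiny' and the timm version that ships it are assumptions, not something this diff guarantees:

# Usage sketch (assumed model name and installed timm version).
import timm
import torch

model = timm.create_model('davit_tiny', pretrained=False, num_classes=10)
model.set_grad_checkpointing(True)  # now also calls set_grad_checkpointing on each stage

x = torch.randn(2, 3, 224, 224)
model(x).sum().backward()  # stage activations are recomputed in backward, lowering peak memory
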
@@ -535,7 +541,10 @@ class DaViT(nn.Module):
 
     def forward_features(self, x):
         x = self.stem(x)
-        x = self.stages(x)
+        if self.grad_checkpointing and not torch.jit.is_scripting():
+            x = checkpoint_seq(self.stages, x)
+        else:
+            x = self.stages(x)
         x = self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
         return x
 
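
The last hunk swaps the unconditional x = self.stages(x) for a guarded call to checkpoint_seq, so the whole stage sequence is checkpointed when the flag is set. The contract is that the forward result is unchanged; only what is stored for backward differs. A rough stand-in demonstration using torch's own checkpoint_sequential (not timm's checkpoint_seq helper):

# Demonstration that checkpointing changes backward-time memory/compute, not the
# forward values; torch.utils.checkpoint.checkpoint_sequential stands in for
# timm's checkpoint_seq here.
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

stages = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 8, 3, padding=1),
)
x = torch.randn(1, 3, 32, 32, requires_grad=True)

plain = stages(x)
ckpt = checkpoint_sequential(stages, 3, x)  # 3 segments; activations recomputed in backward
print(torch.allclose(plain, ckpt))  # True: identical outputs either way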