From 03c779f1cfda2d0607606cb119a49564c018b9fa Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Sat, 10 Dec 2022 05:05:00 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 7275aa33..8d288bd0 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -35,9 +35,9 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] # modified nn.Sequential that includes a size tuple in the forward function +# FIXME doesn't work with torchscript/JIT class SequentialWithSize(nn.Sequential): - @torch.jit.ignore def forward(self, x : Tensor, size: Tuple[int, int]): for module in self._modules.values(): x, size = module(x, size) @@ -434,7 +434,6 @@ class DaViTStage(nn.Module): self.blocks = SequentialWithSize(*stage_blocks) - @torch.jit.ignore def forward(self, x : Tensor, size: Tuple[int, int]): x, size = self.patch_embed(x, size) if self.grad_checkpointing and not torch.jit.is_scripting():