diff --git a/timm/models/davit.py b/timm/models/davit.py index 7275aa33..8d288bd0 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -35,9 +35,9 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD __all__ = ['DaViT'] # modified nn.Sequential that includes a size tuple in the forward function +# FIXME doesn't work with torchscript/JIT class SequentialWithSize(nn.Sequential): - @torch.jit.ignore def forward(self, x : Tensor, size: Tuple[int, int]): for module in self._modules.values(): x, size = module(x, size) @@ -434,7 +434,6 @@ class DaViTStage(nn.Module): self.blocks = SequentialWithSize(*stage_blocks) - @torch.jit.ignore def forward(self, x : Tensor, size: Tuple[int, int]): x, size = self.patch_embed(x, size) if self.grad_checkpointing and not torch.jit.is_scripting():