From f8c5150fcdb2b56d1fce85484249b094ab69f509 Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 7 Dec 2022 01:04:09 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index 440c912d..bef04e87 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -35,6 +35,7 @@ from .registry import register_model __all__ = ['DaViT'] +''' class MySequential(nn.Sequential): def forward(self, inputs : Tuple[Tensor, Tuple[int, int]]): for module in self: @@ -45,7 +46,7 @@ class MySequential(nn.Sequential): # inputs = module(inputs) return inputs - +''' ''' class MySequential(nn.Sequential): @overload @@ -517,9 +518,9 @@ class DaViT(nn.Module): for _, layer in enumerate(block): print(layer) if self.grad_checkpointing and not torch.jit.is_scripting(): - features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1])) + features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) else: - features[-1], sizes[-1] = layer((features[-1], sizes[-1])) + features[-1], sizes[-1] = layer(features[-1], sizes[-1]) features.append(features[-1])