From 8e421bccfbe9d1fc9f0cb07a1bd66af09185747e Mon Sep 17 00:00:00 2001 From: Fredo Guan Date: Wed, 7 Dec 2022 00:48:14 -0800 Subject: [PATCH] Update davit.py --- timm/models/davit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/models/davit.py b/timm/models/davit.py index aa006a7d..c1f16f6a 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -516,9 +516,9 @@ class DaViT(nn.Module): for block in enumerate(stage): for layer in enumerate(block): if self.grad_checkpointing and not torch.jit.is_scripting(): - features[-1], sizes[-1] = checkpoint.checkpoint(layer, features[-1], sizes[-1]) + features[-1], sizes[-1] = checkpoint.checkpoint(layer, (features[-1], sizes[-1])) else: - features[-1], sizes[-1] = layer(features[-1], sizes[-1]) + features[-1], sizes[-1] = layer((features[-1], sizes[-1])) features.append(features[-1])