diff --git a/timm/models/metaformers.py b/timm/models/metaformers.py index 1cb13141..30cba4cc 100644 --- a/timm/models/metaformers.py +++ b/timm/models/metaformers.py @@ -546,7 +546,7 @@ class MetaFormer(nn.Module): @torch.jit.ignore def set_grad_checkpointing(self, enable=True): - print("not implemented") + self.grad_checkpointing = enable @torch.jit.ignore def get_classifier(self): @@ -574,7 +574,10 @@ class MetaFormer(nn.Module): def forward_features(self, x): x = self.stem(x) - x = self.stages(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.stages, x) + else: + x = self.stages(x) x = self.norm_pre(x) return x @@ -583,6 +586,8 @@ class MetaFormer(nn.Module): x = self.forward_head(x) return x +# FIXME convert to group matcher +# this works but it's long and breaks backwards compatibility with weights from the poolformer-only impl def checkpoint_filter_fn(state_dict, model): import re out_dict = {} @@ -817,6 +822,7 @@ default_cfgs = generate_default_cfgs({ classifier='head.fc.fc2', num_classes=21841), }) +# FIXME fully merge poolformerv1, rename to poolformer to succeed poolformer.py @register_model def poolformerv1_s12(pretrained=False, **kwargs):