diff --git a/timm/models/davit.py b/timm/models/davit.py index ee154d56..61393e59 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -399,7 +399,7 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - self._features_only = features_only + self._features_only = kwargs.get(features_only, False) self.feature_info = [] self.patch_embeds = nn.ModuleList([