diff --git a/timm/models/davit.py b/timm/models/davit.py index 61393e59..61eeba9c 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -399,7 +399,7 @@ class DaViT(nn.Module): self.num_features = embed_dims[-1] self.drop_rate=drop_rate self.grad_checkpointing = False - self._features_only = kwargs.get(features_only, False) + self._features_only = kwargs.get('features_only', False) self.feature_info = [] self.patch_embeds = nn.ModuleList([