diff --git a/timm/models/davit.py b/timm/models/davit.py index b9bcc8b9..06f65020 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -382,7 +382,7 @@ class DaViT(nn.Module): attn_drop_rate=0., img_size=224, num_classes=1000, - global_pool='avg' + global_pool='avg', features_only = False ): super().__init__()