diff --git a/timm/models/davit.py b/timm/models/davit.py index 93ef42c9..cf2461c3 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -524,7 +524,7 @@ class DaViT(nn.Module): class DaViTFeatures(DaViT): - def __init__(*args): + def __init__(*args, **kwargs): super(DaViT, self).__init__(*args, **kwargs) self.feature_info = FeatureInfo(self.feature_info, kwargs.get('out_inices', (1, 2, 3, 4)))