diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 5f345135..c28b189e 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -1,4 +1,4 @@ -""" ConvNext +""" ConvNeXt Paper: `A ConvNet for the 2020s` - https://arxiv.org/pdf/2201.03545.pdf @@ -229,6 +229,7 @@ class ConvNeXt(nn.Module): ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), ('norm', norm_layer(self.num_features)), ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()) ]))