From 2e83bba1422295f177c2681d5835012633cbae19 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 9 Jan 2023 13:37:40 -0800
Subject: [PATCH] Revert head norm changes to ConvNeXt as it broke some downstream use, alternate workaround for fcmae weights

---
 timm/models/convnext.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index e799a7de..11c061f6 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -301,11 +301,10 @@ class ConvNeXt(nn.Module):
 
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.head_norm_first = head_norm_first
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
             ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-            ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
+            ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
             ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
             ('drop', nn.Dropout(self.drop_rate)),
             ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
@@ -336,14 +335,7 @@ class ConvNeXt(nn.Module):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
-        else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
@@ -407,6 +399,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_convnext(variant, pretrained=False, **kwargs):
+    if kwargs.get('pretrained_cfg', '') == 'fcmae':
+        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
+        # This is workaround loading with num_classes=0 w/o removing norm-layer.
+        kwargs.setdefault('pretrained_strict', False)
+
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
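
[Note, not part of the patch above] A minimal usage sketch of the fcmae loading path this patch targets, assuming a `convnext_base.fcmae` pretrained tag is registered in timm; the model/tag name is illustrative, only the `pretrained_cfg == 'fcmae'` check comes from the diff:

    import timm

    # fcmae checkpoints ship without a classifier and without the final `head.norm`,
    # so the patched _create_convnext() defaults pretrained_strict=False for them.
    # The model can then be created as a headless feature backbone even though it
    # still constructs its head norm layer. Tag name below is illustrative.
    backbone = timm.create_model('convnext_base.fcmae', pretrained=True, num_classes=0)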