@@ -301,11 +301,10 @@ class ConvNeXt(nn.Module):
         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-        self.head_norm_first = head_norm_first
         self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
         self.head = nn.Sequential(OrderedDict([
                 ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
-                ('norm', nn.Identity() if head_norm_first or num_classes == 0 else norm_layer(self.num_features)),
+                ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
                 ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
                 ('drop', nn.Dropout(self.drop_rate)),
                 ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
 
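The two head orderings the comments describe differ only in where the final norm sits relative to global pooling. A minimal sketch of the distinction, using a plain LayerNorm and mean pooling as stand-ins for `norm_layer` and `SelectAdaptivePool2d` (illustrative only, not the timm implementation):

import torch
import torch.nn as nn

features = torch.randn(2, 768, 7, 7)  # (B, C, H, W), as produced by forward_features
norm = nn.LayerNorm(768)
fc = nn.Linear(768, 1000)

# head_norm_first=True: norm -> global pool -> fc (like most other nets)
x = norm(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # norm over channels
out1 = fc(x.mean(dim=(2, 3)))

# head_norm_first=False: pool -> norm -> fc (ConvNeXt default, matches FB weights)
out2 = fc(norm(features.mean(dim=(2, 3))))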
@@ -336,14 +335,7 @@ class ConvNeXt(nn.Module):
         if global_pool is not None:
             self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
             self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
-        if num_classes == 0:
-            self.head.norm = nn.Identity()
-            self.head.fc = nn.Identity()
-        else:
-            if not self.head_norm_first:
-                norm_layer = type(self.stem[-1])  # obtain type from stem norm
-                self.head.norm = norm_layer(self.num_features)
-            self.head.fc = nn.Linear(self.num_features, num_classes)
+        self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
     def forward_features(self, x):
         x = self.stem(x)
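Since the head norm is now unconditionally instantiated in `__init__` whenever `head_norm_first` is False, `reset_classifier` no longer needs to track `head_norm_first` or rebuild the norm; it only swaps the final linear layer. A hypothetical usage sketch (names per the timm API shown above):

import torch

model = ConvNeXt(num_classes=1000)
model.reset_classifier(0)  # head.fc becomes nn.Identity(); head.norm stays in place
feats = model(torch.randn(1, 3, 224, 224))  # (1, num_features) pooled, normed features
model.reset_classifier(10)  # attach a fresh 10-way classifier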
@@ -407,6 +399,11 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_convnext(variant, pretrained=False, **kwargs):
+    if kwargs.get('pretrained_cfg', '') == 'fcmae':
+        # NOTE fcmae pretrained weights have no classifier or final norm-layer (`head.norm`)
+        # This is a workaround for loading with num_classes=0 w/o removing the norm-layer.
+        kwargs.setdefault('pretrained_strict', False)
+
     model = build_model_with_cfg(
         ConvNeXt, variant, pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
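The `pretrained_strict` escape hatch corresponds to PyTorch's non-strict state-dict loading: fcmae (ConvNeXt-V2 masked-autoencoder) checkpoints contain no `head.norm` or `head.fc` weights, so a strict load would fail on the missing keys. In plain PyTorch terms, the behavior being relied on looks roughly like this (a sketch; `fcmae_state_dict` is a placeholder for the filtered checkpoint):

# strict=True would raise RuntimeError on the absent 'head.norm.*' / 'head.fc.*' keys;
# strict=False loads everything present and reports the rest as missing
result = model.load_state_dict(fcmae_state_dict, strict=False)
print(result.missing_keys)  # e.g. ['head.norm.weight', 'head.norm.bias', ...]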