From 82ce70eec33eb0be96d6e55fcd46d9e1a1c82f77 Mon Sep 17 00:00:00 2001
From: Ross Wightman <rwightman@gmail.com>
Date: Thu, 13 Jan 2022 21:10:32 -0800
Subject: [PATCH] Some convnext cleanup, remove in place mul_ for gamma,
 breaking symbolic trace, cleanup head a bit...

---
 timm/models/convnext.py | 72 ++++++++++++++++++-----------------------
 1 file changed, 32 insertions(+), 40 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index aa8112cb..5f345135 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -9,7 +9,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 # This source code is licensed under the MIT license
-
+from collections import OrderedDict
 from functools import partial
 
 import torch
@@ -32,7 +32,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0', 'classifier': 'head',
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -43,7 +43,7 @@ default_cfgs = dict(
     convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
 
-    convnext_tiny_hnf=_cfg(url='', classifier='head.fc'),
+    convnext_tiny_hnf=_cfg(url=''),
 
     convnext_base_in22k=_cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
@@ -65,16 +65,12 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
 
 
 @register_notrace_module
-class LayerNorm2d(nn.Module):
+class LayerNorm2d(nn.LayerNorm):
     r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
     """
 
     def __init__(self, normalized_shape, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.normalized_shape = (normalized_shape,)
+        super().__init__(normalized_shape, eps=eps)
 
     def forward(self, x) -> torch.Tensor:
         if _is_contiguous(x):
@@ -105,7 +101,8 @@ class ConvNeXtBlock(nn.Module):
 
     def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
         super().__init__()
-        norm_layer = norm_layer or (partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6))
+        if not norm_layer:
+            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
         self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
@@ -120,15 +117,13 @@ class ConvNeXtBlock(nn.Module):
         if self.use_conv_mlp:
             x = self.norm(x)
             x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma.reshape(1, -1, 1, 1))
         else:
             x = x.permute(0, 2, 3, 1)
             x = self.norm(x)
             x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma)
             x = x.permute(0, 3, 1, 2)
+        if self.gamma is not None:
+            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
         x = self.drop_path(x) + shortcut
         return x
 
@@ -191,7 +186,6 @@ class ConvNeXt(nn.Module):
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
             cl_norm_layer = norm_layer
 
-        partial(LayerNorm2d, eps=1e-6)
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []
@@ -226,51 +220,46 @@ class ConvNeXt(nn.Module):
         self.num_features = prev_chs
         if head_norm_first:
             # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
-            self.norm = norm_layer(self.num_features)  # final norm layer
-            self.pool = None  # global pool in ClassifierHead, pool == None being used to differentiate
+            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling
             self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         else:
             # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            # NOTE when cl_norm_layer != norm_layer we could flatten here and use cl, but makes no performance diff
-            self.norm = norm_layer(self.num_features)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.norm_pre = nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', norm_layer(self.num_features)),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))
 
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
     def get_classifier(self):
-        return self.head.fc if self.pool is None else self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool='avg'):
-        if self.pool is None:
-            # norm -> global pool -> fc ordering
+        if isinstance(self.head, ClassifierHead):
+            # norm -> global pool -> fc
             self.head = ClassifierHead(
                 self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
         else:
             # pool -> norm -> fc
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', self.head.norm),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('drop', nn.Dropout(self.drop_rate)),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))
 
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
-        if self.pool is None:
-            # standard head, norm -> spatial pool -> fc
-            # ideally, last norm is within forward_features, but can only do so if norm precedes pooling
-            x = self.norm(x)
+        x = self.norm_pre(x)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
-        if self.pool is not None:
-            # ConvNeXt head, spatial pool -> norm -> fc
-            # FIXME clean this up
-            x = self.pool(x)
-            x = self.norm(x)
-            if not self.pool.is_identity():
-                x = x.flatten(1)
-            if self.drop_rate > 0:
-                x = F.dropout(x, self.drop_rate, self.training)
         x = self.head(x)
         return x
 
@@ -282,7 +271,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
     elif isinstance(module, nn.Linear):
         trunc_normal_(module.weight, std=.02)
         nn.init.constant_(module.bias, 0)
-        if name and '.head' in name:
+        if name and 'head.' in name:
             module.weight.data.mul_(head_init_scale)
             module.bias.data.mul_(head_init_scale)
 
@@ -299,6 +288,9 @@ def checkpoint_filter_fn(state_dict, model):
         k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
         k = k.replace('dwconv', 'conv_dw')
         k = k.replace('pwconv', 'mlp.fc')
+        k = k.replace('head.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm', 'head.norm')
         if v.ndim == 2 and 'head' not in k:
             model_shape = model.state_dict()[k].shape
             v = v.reshape(model_shape)