From b093dcb46dd2459b6e8af120ad71c9a2b11d6fcd Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Thu, 13 Jan 2022 21:10:32 -0800
Subject: [PATCH] Some convnext cleanup, remove in place mul_ for gamma,
 breaking symbolic trace, cleanup head a bit...

---
 timm/models/convnext.py | 72 ++++++++++++++++++-----------------------
 1 file changed, 32 insertions(+), 40 deletions(-)

diff --git a/timm/models/convnext.py b/timm/models/convnext.py
index aa8112cb..5f345135 100644
--- a/timm/models/convnext.py
+++ b/timm/models/convnext.py
@@ -9,7 +9,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
 # This source code is licensed under the MIT license
-
+from collections import OrderedDict
 from functools import partial
 
 import torch
@@ -32,7 +32,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'stem.0', 'classifier': 'head',
+        'first_conv': 'stem.0', 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -43,7 +43,7 @@ default_cfgs = dict(
     convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
     convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
 
-    convnext_tiny_hnf=_cfg(url='', classifier='head.fc'),
+    convnext_tiny_hnf=_cfg(url=''),
 
     convnext_base_in22k=_cfg(
         url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
@@ -65,16 +65,12 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
 
 
 @register_notrace_module
-class LayerNorm2d(nn.Module):
+class LayerNorm2d(nn.LayerNorm):
     r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
     """
 
     def __init__(self, normalized_shape, eps=1e-6):
-        super().__init__()
-        self.weight = nn.Parameter(torch.ones(normalized_shape))
-        self.bias = nn.Parameter(torch.zeros(normalized_shape))
-        self.eps = eps
-        self.normalized_shape = (normalized_shape,)
+        super().__init__(normalized_shape, eps=eps)
 
     def forward(self, x) -> torch.Tensor:
         if _is_contiguous(x):
@@ -105,7 +101,8 @@ class ConvNeXtBlock(nn.Module):
 
     def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
         super().__init__()
-        norm_layer = norm_layer or (partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6))
+        if not norm_layer:
+            norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
         mlp_layer = ConvMlp if conv_mlp else Mlp
         self.use_conv_mlp = conv_mlp
         self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise conv
@@ -120,15 +117,13 @@ class ConvNeXtBlock(nn.Module):
         x = self.conv_dw(x)
         if self.use_conv_mlp:
             x = self.norm(x)
             x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma.reshape(1, -1, 1, 1))
         else:
             x = x.permute(0, 2, 3, 1)
             x = self.norm(x)
             x = self.mlp(x)
-            if self.gamma is not None:
-                x.mul_(self.gamma)
             x = x.permute(0, 3, 1, 2)
+        if self.gamma is not None:
+            x = x.mul(self.gamma.reshape(1, -1, 1, 1))
         x = self.drop_path(x) + shortcut
         return x
@@ -191,7 +186,6 @@ class ConvNeXt(nn.Module):
                 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
             cl_norm_layer = norm_layer
 
-        partial(LayerNorm2d, eps=1e-6)
         self.num_classes = num_classes
         self.drop_rate = drop_rate
         self.feature_info = []
@@ -226,51 +220,46 @@ class ConvNeXt(nn.Module):
 
         self.num_features = prev_chs
         if head_norm_first:
             # norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
-            self.norm = norm_layer(self.num_features)  # final norm layer
-            self.pool = None  # global pool in ClassifierHead, pool == None being used to differentiate
+            self.norm_pre = norm_layer(self.num_features)  # final norm layer, before pooling
             self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
         else:
             # pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            # NOTE when cl_norm_layer != norm_layer we could flatten here and use cl, but makes no performance diff
-            self.norm = norm_layer(self.num_features)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.norm_pre = nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', norm_layer(self.num_features)),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))
 
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
 
     def get_classifier(self):
-        return self.head.fc if self.pool is None else self.head
+        return self.head.fc
 
     def reset_classifier(self, num_classes=0, global_pool='avg'):
-        if self.pool is None:
-            # norm -> global pool -> fc ordering
+        if isinstance(self.head, ClassifierHead):
+            # norm -> global pool -> fc
             self.head = ClassifierHead(
                 self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
         else:
             # pool -> norm -> fc
-            self.pool = SelectAdaptivePool2d(pool_type=global_pool)
-            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+            self.head = nn.Sequential(OrderedDict([
+                ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
+                ('norm', self.head.norm),
+                ('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
+                ('drop', nn.Dropout(self.drop_rate)),
+                ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
+            ]))
 
     def forward_features(self, x):
         x = self.stem(x)
         x = self.stages(x)
-        if self.pool is None:
-            # standard head, norm -> spatial pool -> fc
-            # ideally, last norm is within forward_features, but can only do so if norm precedes pooling
-            x = self.norm(x)
+        x = self.norm_pre(x)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
-        if self.pool is not None:
-            # ConvNeXt head, spatial pool -> norm -> fc
-            # FIXME clean this up
-            x = self.pool(x)
-            x = self.norm(x)
-            if not self.pool.is_identity():
-                x = x.flatten(1)
-            if self.drop_rate > 0:
-                x = F.dropout(x, self.drop_rate, self.training)
         x = self.head(x)
         return x
@@ -282,7 +271,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
     elif isinstance(module, nn.Linear):
         trunc_normal_(module.weight, std=.02)
         nn.init.constant_(module.bias, 0)
-        if name and '.head' in name:
+        if name and 'head.' in name:
             module.weight.data.mul_(head_init_scale)
             module.bias.data.mul_(head_init_scale)
 
@@ -299,6 +288,9 @@ def checkpoint_filter_fn(state_dict, model):
         k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
         k = k.replace('dwconv', 'conv_dw')
         k = k.replace('pwconv', 'mlp.fc')
+        k = k.replace('head.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm', 'head.norm')
         if v.ndim == 2 and 'head' not in k:
             model_shape = model.state_dict()[k].shape
             v = v.reshape(model_shape)
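
Note on the gamma change: per the subject line, the in-place x.mul_(...) was
breaking torch.fx symbolic tracing, so layer scale is now applied out-of-place
once, after either MLP branch. A minimal sketch of the pattern follows; the
standalone LayerScale module is hypothetical, for illustration only, not a
timm class.

    import torch
    import torch.fx
    import torch.nn as nn

    class LayerScale(nn.Module):
        # out-of-place per-channel scaling on NCHW tensors, as in the patched block
        def __init__(self, dim, init_value=1e-6):
            super().__init__()
            self.gamma = nn.Parameter(init_value * torch.ones(dim))

        def forward(self, x):
            # reshape gamma to (1, C, 1, 1) so it broadcasts over N, H, W;
            # x.mul(...) records a functional node that fx can trace, where
            # the old x.mul_(...) mutated the traced proxy in place
            return x.mul(self.gamma.reshape(1, -1, 1, 1))

    traced = torch.fx.symbolic_trace(LayerScale(64))
    print(traced.code)  # the generated forward contains a plain `mul` call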
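Note on the head refactor: pool -> norm -> flatten -> fc is now packed into a
single nn.Sequential keyed by an OrderedDict, which is why the default cfg
classifier moved from 'head' to 'head.fc' and get_classifier() can always
return self.head.fc. A rough shape-check sketch, with plain torch modules
standing in for timm's SelectAdaptivePool2d and LayerNorm2d (GroupNorm with
one group over a 1x1 map behaves like LayerNorm over channels):

    from collections import OrderedDict
    import torch
    import torch.nn as nn

    num_features, num_classes = 768, 1000
    head = nn.Sequential(OrderedDict([
        ('global_pool', nn.AdaptiveAvgPool2d(1)),  # stand-in for SelectAdaptivePool2d
        ('norm', nn.GroupNorm(1, num_features)),   # stand-in for LayerNorm2d
        ('flatten', nn.Flatten(1)),
        ('fc', nn.Linear(num_features, num_classes)),
    ]))

    x = torch.randn(2, num_features, 7, 7)
    print(head(x).shape)  # torch.Size([2, 1000])
    print(head.fc)        # classifier reachable at the stable 'head.fc' path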
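Note on checkpoint_filter_fn: the added lines remap the original FB ConvNeXt
head/norm keys onto the new head module paths. A quick standalone check of
just the substitutions visible in this patch (the full filter applies further
renames not shown here):

    import re

    def remap_key(k):
        # only the substitutions visible in the hunk above
        k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
        k = k.replace('dwconv', 'conv_dw')
        k = k.replace('pwconv', 'mlp.fc')
        k = k.replace('head.', 'head.fc.')
        if k.startswith('norm.'):
            k = k.replace('norm', 'head.norm')
        return k

    for k in ('downsample_layers.0.0.weight', 'norm.weight', 'head.weight', 'head.bias'):
        print(f'{k:28s} -> {remap_key(k)}')
    # downsample_layers.0.0.weight -> stages.0.downsample.0.weight
    # norm.weight                  -> head.norm.weight
    # head.weight                  -> head.fc.weight
    # head.bias                    -> head.fc.bias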