Some convnext cleanup, remove in place mul_ for gamma, breaking symbolic trace, cleanup head a bit...

pull/1083/head
Ross Wightman 2 years ago
parent c767acd39f
commit 82ce70eec3

@ -9,7 +9,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the MIT license
from collections import OrderedDict
from functools import partial
import torch
@ -32,7 +32,7 @@ def _cfg(url='', **kwargs):
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head',
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
@ -43,7 +43,7 @@ default_cfgs = dict(
convnext_base=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth"),
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_tiny_hnf=_cfg(url='', classifier='head.fc'),
convnext_tiny_hnf=_cfg(url=''),
convnext_base_in22k=_cfg(
url="https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", num_classes=21841),
@ -65,16 +65,12 @@ def _is_contiguous(tensor: torch.Tensor) -> bool:
@register_notrace_module
class LayerNorm2d(nn.Module):
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.normalized_shape = (normalized_shape,)
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
@ -105,7 +101,8 @@ class ConvNeXtBlock(nn.Module):
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=True, mlp_ratio=4, norm_layer=None):
super().__init__()
norm_layer = norm_layer or (partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6))
if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
@ -120,15 +117,13 @@ class ConvNeXtBlock(nn.Module):
if self.use_conv_mlp:
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x.mul_(self.gamma.reshape(1, -1, 1, 1))
else:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x.mul_(self.gamma)
x = x.permute(0, 3, 1, 2)
if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut
return x
@ -191,7 +186,6 @@ class ConvNeXt(nn.Module):
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
cl_norm_layer = norm_layer
partial(LayerNorm2d, eps=1e-6)
self.num_classes = num_classes
self.drop_rate = drop_rate
self.feature_info = []
@ -226,51 +220,46 @@ class ConvNeXt(nn.Module):
self.num_features = prev_chs
if head_norm_first:
# norm -> global pool -> fc ordering, like most other nets (not compat with FB weights)
self.norm = norm_layer(self.num_features) # final norm layer
self.pool = None # global pool in ClassifierHead, pool == None being used to differentiate
self.norm_pre = norm_layer(self.num_features) # final norm layer, before pooling
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=drop_rate)
else:
# pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.pool = SelectAdaptivePool2d(pool_type=global_pool)
# NOTE when cl_norm_layer != norm_layer we could flatten here and use cl, but makes no performance diff
self.norm = norm_layer(self.num_features)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.norm_pre = nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
def get_classifier(self):
return self.head.fc if self.pool is None else self.head
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool='avg'):
if self.pool is None:
# norm -> global pool -> fc ordering
if isinstance(self.head, ClassifierHead):
# norm -> global pool -> fc
self.head = ClassifierHead(
self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
else:
# pool -> norm -> fc
self.pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', self.head.norm),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())
]))
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
if self.pool is None:
# standard head, norm -> spatial pool -> fc
# ideally, last norm is within forward_features, but can only do so if norm precedes pooling
x = self.norm(x)
x = self.norm_pre(x)
return x
def forward(self, x):
x = self.forward_features(x)
if self.pool is not None:
# ConvNeXt head, spatial pool -> norm -> fc
# FIXME clean this up
x = self.pool(x)
x = self.norm(x)
if not self.pool.is_identity():
x = x.flatten(1)
if self.drop_rate > 0:
x = F.dropout(x, self.drop_rate, self.training)
x = self.head(x)
return x
@ -282,7 +271,7 @@ def _init_weights(module, name=None, head_init_scale=1.0):
elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0)
if name and '.head' in name:
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
@ -299,6 +288,9 @@ def checkpoint_filter_fn(state_dict, model):
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
if v.ndim == 2 and 'head' not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)

Loading…
Cancel
Save