From 6f28b562c619c0f39bbbdd4bbba4f502677e8515 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 27 Jan 2023 14:57:01 -0800 Subject: [PATCH] Factor NormMlpClassifierHead from MaxxViT and use across MaxxViT / ConvNeXt / DaViT, refactor some type hints & comments --- timm/layers/__init__.py | 2 +- timm/layers/classifier.py | 98 +++++++++++++++++++++++++++++- timm/models/convnext.py | 122 ++++++++++++++++++++------------------ timm/models/davit.py | 31 ++++++---- timm/models/maxxvit.py | 67 +-------------------- 5 files changed, 183 insertions(+), 137 deletions(-) diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 6b2dabba..8e555b8b 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \ adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .blur_pool import BlurPool2d -from .classifier import ClassifierHead, create_classifier +from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead from .cond_conv2d import CondConv2d, get_condconv_initializer from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ set_layer_config diff --git a/timm/layers/classifier.py b/timm/layers/classifier.py index e885084c..d93d0ec7 100644 --- a/timm/layers/classifier.py +++ b/timm/layers/classifier.py @@ -2,10 +2,17 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from torch import nn as nn +from collections import OrderedDict +from functools import partial +from typing import Optional, Union, Callable + +import torch +import torch.nn as nn from torch.nn import functional as F from .adaptive_avgmax_pool import SelectAdaptivePool2d +from .create_act import get_act_layer +from .create_norm import get_norm_layer def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): @@ -38,7 +45,21 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False class ClassifierHead(nn.Module): """Classifier head w/ configurable global pooling and dropout.""" - def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False): + def __init__( + self, + in_features: int, + num_classes: int, + pool_type: str = 'avg', + drop_rate: float = 0., + use_conv: bool = False, + ): + """ + Args: + in_features: The number of input features. + num_classes: The number of classes for the final classifier layer (output). + pool_type: Global pooling type, pooling disabled if empty string (''). + drop_rate: Pre-classifier dropout rate. + """ super(ClassifierHead, self).__init__() self.drop_rate = drop_rate self.in_features = in_features @@ -65,3 +86,76 @@ class ClassifierHead(nn.Module): else: x = self.fc(x) return self.flatten(x) + + +class NormMlpClassifierHead(nn.Module): + + def __init__( + self, + in_features: int, + num_classes: int, + hidden_size: Optional[int] = None, + pool_type: str = 'avg', + drop_rate: float = 0., + norm_layer: Union[str, Callable] = 'layernorm2d', + act_layer: Union[str, Callable] = 'tanh', + ): + """ + Args: + in_features: The number of input features. + num_classes: The number of classes for the final classifier layer (output). + hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None. + pool_type: Global pooling type, pooling disabled if empty string (''). + drop_rate: Pre-classifier dropout rate. + norm_layer: Normalization layer type. 
+ act_layer: MLP activation layer type (only used if hidden_size is not None). + """ + super().__init__() + self.drop_rate = drop_rate + self.in_features = in_features + self.hidden_size = hidden_size + self.num_features = in_features + self.use_conv = not pool_type + norm_layer = get_norm_layer(norm_layer) + act_layer = get_act_layer(act_layer) + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear + + self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) + self.norm = norm_layer(in_features) + self.flatten = nn.Flatten(1) if pool_type else nn.Identity() + if hidden_size: + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', linear_layer(in_features, hidden_size)), + ('act', act_layer()), + ])) + self.num_features = hidden_size + else: + self.pre_logits = nn.Identity() + self.drop = nn.Dropout(self.drop_rate) + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def reset(self, num_classes, global_pool=None): + if global_pool is not None: + self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.use_conv = self.global_pool.is_identity() + linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear + if self.hidden_size: + if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or + (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)): + with torch.no_grad(): + new_fc = linear_layer(self.in_features, self.hidden_size) + new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape)) + new_fc.bias.copy_(self.pre_logits.fc.bias) + self.pre_logits.fc = new_fc + self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward(self, x, pre_logits: bool = False): + x = self.global_pool(x) + x = self.norm(x) + x = self.flatten(x) + x = self.pre_logits(x) + if pre_logits: + return x + x = self.fc(x) + return x diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 2bbe0b11..1655ad34 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -39,6 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W from collections import OrderedDict from functools import partial +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -46,6 +47,7 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \ LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple +from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._manipulate import named_apply, checkpoint_seq from ._pretrained import generate_default_cfgs @@ -188,48 +190,50 @@ class ConvNeXt(nn.Module): def __init__( self, - in_chans=3, - num_classes=1000, - global_pool='avg', - output_stride=32, - depths=(3, 3, 9, 3), - dims=(96, 192, 384, 768), - kernel_sizes=7, - ls_init_value=1e-6, - stem_type='patch', - patch_size=4, - head_init_scale=1., - head_norm_first=False, - conv_mlp=False, - conv_bias=True, - use_grn=False, - act_layer='gelu', - norm_layer=None, - norm_eps=None, - drop_rate=0., - drop_path_rate=0., + in_chans: int = 3, + num_classes: int = 1000, + global_pool: str = 'avg', + output_stride: int = 32, + depths: Tuple[int, ...] 
= (3, 3, 9, 3), + dims: Tuple[int, ...] = (96, 192, 384, 768), + kernel_sizes: Union[int, Tuple[int, ...]] = 7, + ls_init_value: Optional[float] = 1e-6, + stem_type: str = 'patch', + patch_size: int = 4, + head_init_scale: float = 1., + head_norm_first: bool = False, + head_hidden_size: Optional[int] = None, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = 'gelu', + norm_layer: Optional[Union[str, Callable]] = None, + norm_eps: Optional[float] = None, + drop_rate: float = 0., + drop_path_rate: float = 0., ): """ Args: - in_chans (int): Number of input image channels (default: 3) - num_classes (int): Number of classes for classification head (default: 1000) - global_pool (str): Global pooling type (default: 'avg') - output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) - depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3]) - dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768]) - kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7) - ls_init_value (float): Init value for Layer Scale (default: 1e-6) - stem_type (str): Type of stem (default: 'patch') - patch_size (int): Stem patch size for patch stem (default: 4) - head_init_scale (float): Init scaling value for classifier weights and biases (default: 1) - head_norm_first (bool): Apply normalization before global pool + head (default: False) - conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False) - conv_bias (bool): Use bias layers w/ all convolutions (default: True) - use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False) - act_layer (Union[str, nn.Module]): Activation Layer - norm_layer (Union[str, nn.Module]): Normalization Layer - drop_rate (float): Head dropout rate (default: 0.) - drop_path_rate (float): Stochastic depth rate (default: 0.) + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + depths: Number of blocks at each stage. + dims: Feature dimension at each stage. + kernel_sizes: Depthwise convolution kernel-sizes for each stage. + ls_init_value: Init value for Layer Scale, disabled if None. + stem_type: Type of stem. + patch_size: Stem patch size for patch stem. + head_init_scale: Init scaling value for classifier weights and biases. + head_norm_first: Apply normalization before global pool + head. + head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. + conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. + conv_bias: Use bias layers w/ all convolutions. + use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + drop_rate: Head pre-classifier dropout rate. + drop_path_rate: Stochastic depth drop rate. 
""" super().__init__() assert output_stride in (8, 16, 32) @@ -307,14 +311,26 @@ class ConvNeXt(nn.Module): # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) - self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() - self.head = nn.Sequential(OrderedDict([ - ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), - ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) - + if head_norm_first: + assert not head_hidden_size + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + hidden_size=head_hidden_size, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + act_layer='gelu', + ) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) @torch.jit.ignore @@ -338,10 +354,7 @@ class ConvNeXt(nn.Module): return self.head.fc def reset_classifier(self, num_classes=0, global_pool=None): - if global_pool is not None: - self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() - self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head.reset(num_classes, global_pool=global_pool) def forward_features(self, x): x = self.stem(x) @@ -350,12 +363,7 @@ class ConvNeXt(nn.Module): return x def forward_head(self, x, pre_logits: bool = False): - # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( - x = self.head.global_pool(x) - x = self.head.norm(x) - x = self.head.flatten(x) - x = self.head.drop(x) - return x if pre_logits else self.head.fc(x) + return self.head(x, pre_logits=pre_logits) def forward(self, x): x = self.forward_features(x) diff --git a/timm/models/davit.py b/timm/models/davit.py index 8b9e67b4..e9871265 100644 --- a/timm/models/davit.py +++ b/timm/models/davit.py @@ -23,6 +23,7 @@ from torch import Tensor from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer +from timm.layers import NormMlpClassifierHead, ClassifierHead from ._builder import build_model_with_cfg from ._features_fx import register_notrace_function from ._manipulate import checkpoint_seq @@ -519,14 +520,23 @@ class DaViT(nn.Module): # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt # FIXME generalize this structure to ClassifierHead - self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() - self.head = nn.Sequential(OrderedDict([ - ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), - ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), - ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), - ('drop', nn.Dropout(self.drop_rate)), - ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else 
nn.Identity())])) - + if head_norm_first: + self.norm_pre = norm_layer(self.num_features) + self.head = ClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + ) + else: + self.norm_pre = nn.Identity() + self.head = NormMlpClassifierHead( + self.num_features, + num_classes, + pool_type=global_pool, + drop_rate=self.drop_rate, + norm_layer=norm_layer, + ) self.apply(self._init_weights) def _init_weights(self, m): @@ -546,10 +556,7 @@ class DaViT(nn.Module): return self.head.fc def reset_classifier(self, num_classes, global_pool=None): - if global_pool is not None: - self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() - self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + self.head.reset(num_classes, global_pool=global_pool) def forward_features(self, x): x = self.stem(x) diff --git a/timm/models/maxxvit.py b/timm/models/maxxvit.py index e730fa30..f41dba8b 100644 --- a/timm/models/maxxvit.py +++ b/timm/models/maxxvit.py @@ -44,7 +44,7 @@ import torch from torch import nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD -from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d +from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf @@ -1072,69 +1072,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]): return cfg -class NormMlpHead(nn.Module): - - def __init__( - self, - in_features, - num_classes, - hidden_size=None, - pool_type='avg', - drop_rate=0., - norm_layer='layernorm2d', - act_layer='tanh', - ): - super().__init__() - self.drop_rate = drop_rate - self.in_features = in_features - self.hidden_size = hidden_size - self.num_features = in_features - self.use_conv = not pool_type - norm_layer = get_norm_layer(norm_layer) - act_layer = get_act_layer(act_layer) - linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear - - self.global_pool = SelectAdaptivePool2d(pool_type=pool_type) - self.norm = norm_layer(in_features) - self.flatten = nn.Flatten(1) if pool_type else nn.Identity() - if hidden_size: - self.pre_logits = nn.Sequential(OrderedDict([ - ('fc', linear_layer(in_features, hidden_size)), - ('act', act_layer()), - ])) - self.num_features = hidden_size - else: - self.pre_logits = nn.Identity() - self.drop = nn.Dropout(self.drop_rate) - self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - - def reset(self, num_classes, global_pool=None): - if global_pool is not None: - self.global_pool = SelectAdaptivePool2d(pool_type=global_pool) - self.flatten = nn.Flatten(1) if global_pool else nn.Identity() - self.use_conv = self.global_pool.is_identity() - linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear - if self.hidden_size: - if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or - (isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)): - with torch.no_grad(): - new_fc = linear_layer(self.in_features, self.hidden_size) - new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape)) - 
new_fc.bias.copy_(self.pre_logits.fc.bias) - self.pre_logits.fc = new_fc - self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity() - - def forward(self, x, pre_logits: bool = False): - x = self.global_pool(x) - x = self.norm(x) - x = self.flatten(x) - x = self.pre_logits(x) - if pre_logits: - return x - x = self.fc(x) - return x - - def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs): transformer_kwargs = {} conv_kwargs = {} @@ -1225,7 +1162,7 @@ class MaxxVit(nn.Module): self.head_hidden_size = cfg.head_hidden_size if self.head_hidden_size: self.norm = nn.Identity() - self.head = NormMlpHead( + self.head = NormMlpClassifierHead( self.num_features, num_classes, hidden_size=self.head_hidden_size,
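Below the patch, a minimal usage sketch (not part of the commit) of the NormMlpClassifierHead factored out above. It assumes a timm checkout that already includes this change, so that the new export from timm.layers resolves; the feature sizes (768 channels, 7x7) are illustrative values, not anything mandated by the patch.

import torch
from timm.layers import NormMlpClassifierHead

# NCHW feature map, e.g. the last-stage output of a ConvNeXt / DaViT / MaxxViT backbone.
features = torch.randn(2, 768, 7, 7)

# pool -> norm -> (optional pre-logits fc + act) -> fc, matching the forward() added above.
head = NormMlpClassifierHead(
    in_features=768,
    num_classes=1000,
    hidden_size=512,           # enables the pre-logits block ('tanh' act by default, 'gelu' here)
    pool_type='avg',
    drop_rate=0.,
    norm_layer='layernorm2d',  # resolved via get_norm_layer
    act_layer='gelu',
)

logits = head(features)                   # shape (2, 1000)
embeds = head(features, pre_logits=True)  # shape (2, 512), output of the pre-logits block

# reset() swaps the classifier for a new class count and/or pooling mode; with an empty
# pool_type the head switches to 1x1 conv layers (the pre-logits weights are carried over
# via reshape) and keeps the spatial dimensions.
head.reset(10, global_pool='')
dense_logits = head(features)             # shape (2, 10, 7, 7)
print(logits.shape, embeds.shape, dense_logits.shape)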
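At the model level, the same refactor means reset_classifier() on ConvNeXt / DaViT (and MaxxViT) now simply delegates to head.reset(). A short sketch, again assuming a timm install that contains this patch; the shapes in the comments are for convnext_tiny defaults (224x224 input, 768 final channels).

import torch
import timm

model = timm.create_model('convnext_tiny', pretrained=False, num_classes=1000)
x = torch.randn(1, 3, 224, 224)

model.reset_classifier(10)                 # head.reset(10): new 10-class fc, pooling unchanged
print(model(x).shape)                      # torch.Size([1, 10])

model.reset_classifier(0, global_pool='')  # drop the fc and disable head pooling
print(model.forward_features(x).shape)     # torch.Size([1, 768, 7, 7])
print(model(x).shape)                      # unpooled head output, also (1, 768, 7, 7) here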