Factor NormMlpClassifierHead from MaxxViT and use across MaxxViT / ConvNeXt / DaViT, refactor some type hints & comments

pull/1645/head
Ross Wightman 1 year ago
parent 29fda20e6d
commit 6f28b562c6

@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\ from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config set_layer_config

@ -2,10 +2,17 @@
Hacked together by / Copyright 2020 Ross Wightman Hacked together by / Copyright 2020 Ross Wightman
""" """
from torch import nn as nn from collections import OrderedDict
from functools import partial
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .create_act import get_act_layer
from .create_norm import get_norm_layer
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False): def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
@ -38,7 +45,21 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
class ClassifierHead(nn.Module): class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout.""" """Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False): def __init__(
self,
in_features: int,
num_classes: int,
pool_type: str = 'avg',
drop_rate: float = 0.,
use_conv: bool = False,
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
"""
super(ClassifierHead, self).__init__() super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.in_features = in_features self.in_features = in_features
@ -65,3 +86,76 @@ class ClassifierHead(nn.Module):
else: else:
x = self.fc(x) x = self.fc(x)
return self.flatten(x) return self.flatten(x)
class NormMlpClassifierHead(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
hidden_size: Optional[int] = None,
pool_type: str = 'avg',
drop_rate: float = 0.,
norm_layer: Union[str, Callable] = 'layernorm2d',
act_layer: Union[str, Callable] = 'tanh',
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
norm_layer: Normalization layer type.
act_layer: MLP activation layer type (only used if hidden_size is not None).
"""
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.norm(x)
x = self.flatten(x)
x = self.pre_logits(x)
if pre_logits:
return x
x = self.fc(x)
return x

@ -39,6 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
from collections import OrderedDict from collections import OrderedDict
from functools import partial from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -46,6 +47,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \ from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs from ._pretrained import generate_default_cfgs
@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):
def __init__( def __init__(
self, self,
in_chans=3, in_chans: int = 3,
num_classes=1000, num_classes: int = 1000,
global_pool='avg', global_pool: str = 'avg',
output_stride=32, output_stride: int = 32,
depths=(3, 3, 9, 3), depths: Tuple[int, ...] = (3, 3, 9, 3),
dims=(96, 192, 384, 768), dims: Tuple[int, ...] = (96, 192, 384, 768),
kernel_sizes=7, kernel_sizes: Union[int, Tuple[int, ...]] = 7,
ls_init_value=1e-6, ls_init_value: Optional[float] = 1e-6,
stem_type='patch', stem_type: str = 'patch',
patch_size=4, patch_size: int = 4,
head_init_scale=1., head_init_scale: float = 1.,
head_norm_first=False, head_norm_first: bool = False,
conv_mlp=False, head_hidden_size: Optional[int] = None,
conv_bias=True, conv_mlp: bool = False,
use_grn=False, conv_bias: bool = True,
act_layer='gelu', use_grn: bool = False,
norm_layer=None, act_layer: Union[str, Callable] = 'gelu',
norm_eps=None, norm_layer: Optional[Union[str, Callable]] = None,
drop_rate=0., norm_eps: Optional[float] = None,
drop_path_rate=0., drop_rate: float = 0.,
drop_path_rate: float = 0.,
): ):
""" """
Args: Args:
in_chans (int): Number of input image channels (default: 3) in_chans: Number of input image channels.
num_classes (int): Number of classes for classification head (default: 1000) num_classes: Number of classes for classification head.
global_pool (str): Global pooling type (default: 'avg') global_pool: Global pooling type.
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32) output_stride: Output stride of network, one of (8, 16, 32).
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3]) depths: Number of blocks at each stage.
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768]) dims: Feature dimension at each stage.
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7) kernel_sizes: Depthwise convolution kernel-sizes for each stage.
ls_init_value (float): Init value for Layer Scale (default: 1e-6) ls_init_value: Init value for Layer Scale, disabled if None.
stem_type (str): Type of stem (default: 'patch') stem_type: Type of stem.
patch_size (int): Stem patch size for patch stem (default: 4) patch_size: Stem patch size for patch stem.
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1) head_init_scale: Init scaling value for classifier weights and biases.
head_norm_first (bool): Apply normalization before global pool + head (default: False) head_norm_first: Apply normalization before global pool + head.
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False) head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
conv_bias (bool): Use bias layers w/ all convolutions (default: True) conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False) conv_bias: Use bias layers w/ all convolutions.
act_layer (Union[str, nn.Module]): Activation Layer use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
norm_layer (Union[str, nn.Module]): Normalization Layer act_layer: Activation layer type.
drop_rate (float): Head dropout rate (default: 0.) norm_layer: Normalization layer type.
drop_path_rate (float): Stochastic depth rate (default: 0.) drop_rate: Head pre-classifier dropout rate.
drop_path_rate: Stochastic depth drop rate.
""" """
super().__init__() super().__init__()
assert output_stride in (8, 16, 32) assert output_stride in (8, 16, 32)
@ -307,14 +311,26 @@ class ConvNeXt(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() if head_norm_first:
self.head = nn.Sequential(OrderedDict([ assert not head_hidden_size
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), self.norm_pre = norm_layer(self.num_features)
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), self.head = ClassifierHead(
('flatten', nn.Flatten(1) if global_pool else nn.Identity()), self.num_features,
('drop', nn.Dropout(self.drop_rate)), num_classes,
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
self.norm_pre = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=self.drop_rate,
norm_layer=norm_layer,
act_layer='gelu',
)
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore @torch.jit.ignore
@ -338,10 +354,7 @@ class ConvNeXt(nn.Module):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None): def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None: self.head.reset(num_classes, global_pool=global_pool)
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
@ -350,12 +363,7 @@ class ConvNeXt(nn.Module):
return x return x
def forward_head(self, x, pre_logits: bool = False): def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( return self.head(x, pre_logits=pre_logits)
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x): def forward(self, x):
x = self.forward_features(x) x = self.forward_features(x)

@ -23,6 +23,7 @@ from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq from ._manipulate import checkpoint_seq
@ -519,14 +520,23 @@ class DaViT(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt # otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead # FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() if head_norm_first:
self.head = nn.Sequential(OrderedDict([ self.norm_pre = norm_layer(self.num_features)
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), self.head = ClassifierHead(
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), self.num_features,
('flatten', nn.Flatten(1) if global_pool else nn.Identity()), num_classes,
('drop', nn.Dropout(self.drop_rate)), pool_type=global_pool,
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) drop_rate=self.drop_rate,
)
else:
self.norm_pre = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
norm_layer=norm_layer,
)
self.apply(self._init_weights) self.apply(self._init_weights)
def _init_weights(self, m): def _init_weights(self, m):
@ -546,10 +556,7 @@ class DaViT(nn.Module):
return self.head.fc return self.head.fc
def reset_classifier(self, num_classes, global_pool=None): def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None: self.head.reset(num_classes, global_pool=global_pool)
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)

@ -44,7 +44,7 @@ import torch
from torch import nn from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
@ -1072,69 +1072,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
return cfg return cfg
class NormMlpHead(nn.Module):
def __init__(
self,
in_features,
num_classes,
hidden_size=None,
pool_type='avg',
drop_rate=0.,
norm_layer='layernorm2d',
act_layer='tanh',
):
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.norm(x)
x = self.flatten(x)
x = self.pre_logits(x)
if pre_logits:
return x
x = self.fc(x)
return x
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs): def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {} transformer_kwargs = {}
conv_kwargs = {} conv_kwargs = {}
@ -1225,7 +1162,7 @@ class MaxxVit(nn.Module):
self.head_hidden_size = cfg.head_hidden_size self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size: if self.head_hidden_size:
self.norm = nn.Identity() self.norm = nn.Identity()
self.head = NormMlpHead( self.head = NormMlpClassifierHead(
self.num_features, self.num_features,
num_classes, num_classes,
hidden_size=self.head_hidden_size, hidden_size=self.head_hidden_size,

Loading…
Cancel
Save