Merge remote-tracking branch 'origin/main' into levit_efficientformer_redux

pull/1655/head
Ross Wightman 1 year ago
commit 9d03c6f526

@ -3,7 +3,7 @@ from .adaptive_avgmax_pool import \
adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
from .blur_pool import BlurPool2d
from .classifier import ClassifierHead, create_classifier
from .classifier import ClassifierHead, create_classifier, NormMlpClassifierHead
from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config

@ -2,10 +2,17 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn
from collections import OrderedDict
from functools import partial
from typing import Optional, Union, Callable
import torch
import torch.nn as nn
from torch.nn import functional as F
from .adaptive_avgmax_pool import SelectAdaptivePool2d
from .create_act import get_act_layer
from .create_norm import get_norm_layer
def _create_pool(num_features, num_classes, pool_type='avg', use_conv=False):
@ -38,7 +45,21 @@ def create_classifier(num_features, num_classes, pool_type='avg', use_conv=False
class ClassifierHead(nn.Module):
"""Classifier head w/ configurable global pooling and dropout."""
def __init__(self, in_features, num_classes, pool_type='avg', drop_rate=0., use_conv=False):
def __init__(
self,
in_features: int,
num_classes: int,
pool_type: str = 'avg',
drop_rate: float = 0.,
use_conv: bool = False,
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
"""
super(ClassifierHead, self).__init__()
self.drop_rate = drop_rate
self.in_features = in_features
@ -65,3 +86,76 @@ class ClassifierHead(nn.Module):
else:
x = self.fc(x)
return self.flatten(x)
class NormMlpClassifierHead(nn.Module):
def __init__(
self,
in_features: int,
num_classes: int,
hidden_size: Optional[int] = None,
pool_type: str = 'avg',
drop_rate: float = 0.,
norm_layer: Union[str, Callable] = 'layernorm2d',
act_layer: Union[str, Callable] = 'tanh',
):
"""
Args:
in_features: The number of input features.
num_classes: The number of classes for the final classifier layer (output).
hidden_size: The hidden size of the MLP (pre-logits FC layer) if not None.
pool_type: Global pooling type, pooling disabled if empty string ('').
drop_rate: Pre-classifier dropout rate.
norm_layer: Normalization layer type.
act_layer: MLP activation layer type (only used if hidden_size is not None).
"""
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.norm(x)
x = self.flatten(x)
x = self.pre_logits(x)
if pre_logits:
return x
x = self.fc(x)
return x

@ -39,6 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
from collections import OrderedDict
from functools import partial
from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@ -46,6 +47,7 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._manipulate import named_apply, checkpoint_seq
from ._pretrained import generate_default_cfgs
@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
kernel_sizes=7,
ls_init_value=1e-6,
stem_type='patch',
patch_size=4,
head_init_scale=1.,
head_norm_first=False,
conv_mlp=False,
conv_bias=True,
use_grn=False,
act_layer='gelu',
norm_layer=None,
norm_eps=None,
drop_rate=0.,
drop_path_rate=0.,
in_chans: int = 3,
num_classes: int = 1000,
global_pool: str = 'avg',
output_stride: int = 32,
depths: Tuple[int, ...] = (3, 3, 9, 3),
dims: Tuple[int, ...] = (96, 192, 384, 768),
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
ls_init_value: Optional[float] = 1e-6,
stem_type: str = 'patch',
patch_size: int = 4,
head_init_scale: float = 1.,
head_norm_first: bool = False,
head_hidden_size: Optional[int] = None,
conv_mlp: bool = False,
conv_bias: bool = True,
use_grn: bool = False,
act_layer: Union[str, Callable] = 'gelu',
norm_layer: Optional[Union[str, Callable]] = None,
norm_eps: Optional[float] = None,
drop_rate: float = 0.,
drop_path_rate: float = 0.,
):
"""
Args:
in_chans (int): Number of input image channels (default: 3)
num_classes (int): Number of classes for classification head (default: 1000)
global_pool (str): Global pooling type (default: 'avg')
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
ls_init_value (float): Init value for Layer Scale (default: 1e-6)
stem_type (str): Type of stem (default: 'patch')
patch_size (int): Stem patch size for patch stem (default: 4)
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
head_norm_first (bool): Apply normalization before global pool + head (default: False)
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
conv_bias (bool): Use bias layers w/ all convolutions (default: True)
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
act_layer (Union[str, nn.Module]): Activation Layer
norm_layer (Union[str, nn.Module]): Normalization Layer
drop_rate (float): Head dropout rate (default: 0.)
drop_path_rate (float): Stochastic depth rate (default: 0.)
in_chans: Number of input image channels.
num_classes: Number of classes for classification head.
global_pool: Global pooling type.
output_stride: Output stride of network, one of (8, 16, 32).
depths: Number of blocks at each stage.
dims: Feature dimension at each stage.
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
ls_init_value: Init value for Layer Scale, disabled if None.
stem_type: Type of stem.
patch_size: Stem patch size for patch stem.
head_init_scale: Init scaling value for classifier weights and biases.
head_norm_first: Apply normalization before global pool + head.
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
conv_bias: Use bias layers w/ all convolutions.
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
act_layer: Activation layer type.
norm_layer: Normalization layer type.
drop_rate: Head pre-classifier dropout rate.
drop_path_rate: Stochastic depth drop rate.
"""
super().__init__()
assert output_stride in (8, 16, 32)
@ -307,14 +311,26 @@ class ConvNeXt(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
if head_norm_first:
assert not head_hidden_size
self.norm_pre = norm_layer(self.num_features)
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
self.norm_pre = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=head_hidden_size,
pool_type=global_pool,
drop_rate=self.drop_rate,
norm_layer=norm_layer,
act_layer='gelu',
)
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
@ -338,10 +354,7 @@ class ConvNeXt(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, global_pool=global_pool)
def forward_features(self, x):
x = self.stem(x)
@ -350,12 +363,7 @@ class ConvNeXt(nn.Module):
return x
def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
return self.head(x, pre_logits=pre_logits)
def forward(self, x):
x = self.forward_features(x)
@ -389,6 +397,11 @@ def checkpoint_filter_fn(state_dict, model):
if 'visual.head.proj.weight' in state_dict:
out_dict['head.fc.weight'] = state_dict['visual.head.proj.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.proj.weight'].shape[0])
elif 'visual.head.mlp.fc1.weight' in state_dict:
out_dict['head.pre_logits.fc.weight'] = state_dict['visual.head.mlp.fc1.weight']
out_dict['head.pre_logits.fc.bias'] = state_dict['visual.head.mlp.fc1.bias']
out_dict['head.fc.weight'] = state_dict['visual.head.mlp.fc2.weight']
out_dict['head.fc.bias'] = torch.zeros(state_dict['visual.head.mlp.fc2.weight'].shape[0])
return out_dict
import re
@ -708,6 +721,22 @@ default_cfgs = generate_default_cfgs({
'convnextv2_small.untrained': _cfg(),
# CLIP weights, fine-tuned on in1k or in12k + in1k
'convnext_base.clip_laion2b_augreg_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0),
'convnext_base.clip_laiona_augreg_ft_in1k_384': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 384, 384), pool_size=(12, 12), crop_pct=1.0),
'convnext_large_mlp.clip_laion2b_augreg_ft_in1k': _cfg(
hf_hub_id='timm/',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0
),
# CLIP based weights, original image tower weights and fine-tunes
'convnext_base.clip_laion2b': _cfg(
hf_hub_id='laion/CLIP-convnext_base_w-laion2B-s13B-b82K',
@ -734,6 +763,11 @@ default_cfgs = generate_default_cfgs({
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 320, 320), pool_size=(10, 10), crop_pct=1.0, num_classes=640),
'convnext_large_mlp.clip_laion2b_augreg': _cfg(
hf_hub_id='laion/CLIP-convnext_large_d.laion2B-s26B-b102K-augreg',
hf_hub_filename='open_clip_pytorch_model.bin',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD,
input_size=(3, 256, 256), pool_size=(8, 8), crop_pct=1.0, num_classes=768),
})
@ -846,6 +880,13 @@ def convnext_large(pretrained=False, **kwargs):
return model
@register_model
def convnext_large_mlp(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], head_hidden_size=1536, **kwargs)
model = _create_convnext('convnext_large_mlp', pretrained=pretrained, **model_args)
return model
@register_model
def convnext_xlarge(pretrained=False, **kwargs):
model_args = dict(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs)

@ -23,6 +23,7 @@ from torch import Tensor
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import DropPath, to_2tuple, trunc_normal_, SelectAdaptivePool2d, Mlp, LayerNorm2d, get_norm_layer
from timm.layers import NormMlpClassifierHead, ClassifierHead
from ._builder import build_model_with_cfg
from ._features_fx import register_notrace_function
from ._manipulate import checkpoint_seq
@ -519,14 +520,23 @@ class DaViT(nn.Module):
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default DaViT order, similar to ConvNeXt
# FIXME generalize this structure to ClassifierHead
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
if head_norm_first:
self.norm_pre = norm_layer(self.num_features)
self.head = ClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
)
else:
self.norm_pre = nn.Identity()
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
pool_type=global_pool,
drop_rate=self.drop_rate,
norm_layer=norm_layer,
)
self.apply(self._init_weights)
def _init_weights(self, m):
@ -546,10 +556,7 @@ class DaViT(nn.Module):
return self.head.fc
def reset_classifier(self, num_classes, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
self.head.reset(num_classes, global_pool=global_pool)
def forward_features(self, x):
x = self.stem(x)

@ -36,7 +36,7 @@ Hacked together by / Copyright 2022, Ross Wightman
import math
from collections import OrderedDict
from dataclasses import dataclass, replace
from dataclasses import dataclass, replace, field
from functools import partial
from typing import Callable, Optional, Union, Tuple, List
@ -44,7 +44,7 @@ import torch
from torch import nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.layers import Mlp, ConvMlp, DropPath, ClassifierHead, LayerNorm, SelectAdaptivePool2d
from timm.layers import Mlp, ConvMlp, DropPath, LayerNorm, ClassifierHead, NormMlpClassifierHead
from timm.layers import create_attn, get_act_layer, get_norm_layer, get_norm_act_layer, create_conv2d, create_pool2d
from timm.layers import trunc_normal_tf_, to_2tuple, extend_tuple, make_divisible, _assert
from timm.layers import RelPosMlp, RelPosBias, RelPosBiasTf
@ -133,8 +133,8 @@ class MaxxVitCfg:
block_type: Tuple[Union[str, Tuple[str, ...]], ...] = ('C', 'C', 'T', 'T')
stem_width: Union[int, Tuple[int, int]] = 64
stem_bias: bool = False
conv_cfg: MaxxVitConvCfg = MaxxVitConvCfg()
transformer_cfg: MaxxVitTransformerCfg = MaxxVitTransformerCfg()
conv_cfg: MaxxVitConvCfg = field(default_factory=MaxxVitConvCfg)
transformer_cfg: MaxxVitTransformerCfg = field(default_factory=MaxxVitTransformerCfg)
head_hidden_size: int = None
weight_init: str = 'vit_eff'
@ -1072,69 +1072,6 @@ def cfg_window_size(cfg: MaxxVitTransformerCfg, img_size: Tuple[int, int]):
return cfg
class NormMlpHead(nn.Module):
def __init__(
self,
in_features,
num_classes,
hidden_size=None,
pool_type='avg',
drop_rate=0.,
norm_layer='layernorm2d',
act_layer='tanh',
):
super().__init__()
self.drop_rate = drop_rate
self.in_features = in_features
self.hidden_size = hidden_size
self.num_features = in_features
self.use_conv = not pool_type
norm_layer = get_norm_layer(norm_layer)
act_layer = get_act_layer(act_layer)
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
self.norm = norm_layer(in_features)
self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
if hidden_size:
self.pre_logits = nn.Sequential(OrderedDict([
('fc', linear_layer(in_features, hidden_size)),
('act', act_layer()),
]))
self.num_features = hidden_size
else:
self.pre_logits = nn.Identity()
self.drop = nn.Dropout(self.drop_rate)
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def reset(self, num_classes, global_pool=None):
if global_pool is not None:
self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.use_conv = self.global_pool.is_identity()
linear_layer = partial(nn.Conv2d, kernel_size=1) if self.use_conv else nn.Linear
if self.hidden_size:
if ((isinstance(self.pre_logits.fc, nn.Conv2d) and not self.use_conv) or
(isinstance(self.pre_logits.fc, nn.Linear) and self.use_conv)):
with torch.no_grad():
new_fc = linear_layer(self.in_features, self.hidden_size)
new_fc.weight.copy_(self.pre_logits.fc.weight.reshape(new_fc.weight.shape))
new_fc.bias.copy_(self.pre_logits.fc.bias)
self.pre_logits.fc = new_fc
self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, pre_logits: bool = False):
x = self.global_pool(x)
x = self.norm(x)
x = self.flatten(x)
x = self.pre_logits(x)
if pre_logits:
return x
x = self.fc(x)
return x
def _overlay_kwargs(cfg: MaxxVitCfg, **kwargs):
transformer_kwargs = {}
conv_kwargs = {}
@ -1225,7 +1162,7 @@ class MaxxVit(nn.Module):
self.head_hidden_size = cfg.head_hidden_size
if self.head_hidden_size:
self.norm = nn.Identity()
self.head = NormMlpHead(
self.head = NormMlpClassifierHead(
self.num_features,
num_classes,
hidden_size=self.head_hidden_size,
@ -2342,4 +2279,4 @@ def maxvit_xlarge_tf_384(pretrained=False, **kwargs):
@register_model
def maxvit_xlarge_tf_512(pretrained=False, **kwargs):
return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)
return _create_maxxvit('maxvit_xlarge_tf_512', 'maxvit_xlarge_tf', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save