|
|
@ -39,6 +39,7 @@ Modifications and additions for timm hacked together by / Copyright 2022, Ross W
|
|
|
|
|
|
|
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
from collections import OrderedDict
|
|
|
|
from functools import partial
|
|
|
|
from functools import partial
|
|
|
|
|
|
|
|
from typing import Callable, Optional, Tuple, Union
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
import torch.nn as nn
|
|
|
@ -46,6 +47,7 @@ import torch.nn as nn
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
|
|
|
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
|
|
|
|
from timm.layers import trunc_normal_, SelectAdaptivePool2d, DropPath, Mlp, GlobalResponseNormMlp, \
|
|
|
|
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
|
|
|
LayerNorm2d, LayerNorm, create_conv2d, get_act_layer, make_divisible, to_ntuple
|
|
|
|
|
|
|
|
from timm.layers import NormMlpClassifierHead, ClassifierHead
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
from ._builder import build_model_with_cfg
|
|
|
|
from ._manipulate import named_apply, checkpoint_seq
|
|
|
|
from ._manipulate import named_apply, checkpoint_seq
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
|
from ._pretrained import generate_default_cfgs
|
|
|
@ -188,48 +190,50 @@ class ConvNeXt(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def __init__(
|
|
|
|
def __init__(
|
|
|
|
self,
|
|
|
|
self,
|
|
|
|
in_chans=3,
|
|
|
|
in_chans: int = 3,
|
|
|
|
num_classes=1000,
|
|
|
|
num_classes: int = 1000,
|
|
|
|
global_pool='avg',
|
|
|
|
global_pool: str = 'avg',
|
|
|
|
output_stride=32,
|
|
|
|
output_stride: int = 32,
|
|
|
|
depths=(3, 3, 9, 3),
|
|
|
|
depths: Tuple[int, ...] = (3, 3, 9, 3),
|
|
|
|
dims=(96, 192, 384, 768),
|
|
|
|
dims: Tuple[int, ...] = (96, 192, 384, 768),
|
|
|
|
kernel_sizes=7,
|
|
|
|
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
|
|
|
|
ls_init_value=1e-6,
|
|
|
|
ls_init_value: Optional[float] = 1e-6,
|
|
|
|
stem_type='patch',
|
|
|
|
stem_type: str = 'patch',
|
|
|
|
patch_size=4,
|
|
|
|
patch_size: int = 4,
|
|
|
|
head_init_scale=1.,
|
|
|
|
head_init_scale: float = 1.,
|
|
|
|
head_norm_first=False,
|
|
|
|
head_norm_first: bool = False,
|
|
|
|
conv_mlp=False,
|
|
|
|
head_hidden_size: Optional[int] = None,
|
|
|
|
conv_bias=True,
|
|
|
|
conv_mlp: bool = False,
|
|
|
|
use_grn=False,
|
|
|
|
conv_bias: bool = True,
|
|
|
|
act_layer='gelu',
|
|
|
|
use_grn: bool = False,
|
|
|
|
norm_layer=None,
|
|
|
|
act_layer: Union[str, Callable] = 'gelu',
|
|
|
|
norm_eps=None,
|
|
|
|
norm_layer: Optional[Union[str, Callable]] = None,
|
|
|
|
drop_rate=0.,
|
|
|
|
norm_eps: Optional[float] = None,
|
|
|
|
drop_path_rate=0.,
|
|
|
|
drop_rate: float = 0.,
|
|
|
|
|
|
|
|
drop_path_rate: float = 0.,
|
|
|
|
):
|
|
|
|
):
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
Args:
|
|
|
|
Args:
|
|
|
|
in_chans (int): Number of input image channels (default: 3)
|
|
|
|
in_chans: Number of input image channels.
|
|
|
|
num_classes (int): Number of classes for classification head (default: 1000)
|
|
|
|
num_classes: Number of classes for classification head.
|
|
|
|
global_pool (str): Global pooling type (default: 'avg')
|
|
|
|
global_pool: Global pooling type.
|
|
|
|
output_stride (int): Output stride of network, one of (8, 16, 32) (default: 32)
|
|
|
|
output_stride: Output stride of network, one of (8, 16, 32).
|
|
|
|
depths (tuple(int)): Number of blocks at each stage. (default: [3, 3, 9, 3])
|
|
|
|
depths: Number of blocks at each stage.
|
|
|
|
dims (tuple(int)): Feature dimension at each stage. (default: [96, 192, 384, 768])
|
|
|
|
dims: Feature dimension at each stage.
|
|
|
|
kernel_sizes (Union[int, List[int]]: Depthwise convolution kernel-sizes for each stage (default: 7)
|
|
|
|
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
|
|
|
|
ls_init_value (float): Init value for Layer Scale (default: 1e-6)
|
|
|
|
ls_init_value: Init value for Layer Scale, disabled if None.
|
|
|
|
stem_type (str): Type of stem (default: 'patch')
|
|
|
|
stem_type: Type of stem.
|
|
|
|
patch_size (int): Stem patch size for patch stem (default: 4)
|
|
|
|
patch_size: Stem patch size for patch stem.
|
|
|
|
head_init_scale (float): Init scaling value for classifier weights and biases (default: 1)
|
|
|
|
head_init_scale: Init scaling value for classifier weights and biases.
|
|
|
|
head_norm_first (bool): Apply normalization before global pool + head (default: False)
|
|
|
|
head_norm_first: Apply normalization before global pool + head.
|
|
|
|
conv_mlp (bool): Use 1x1 conv in MLP, improves speed for small networks w/ chan last (default: False)
|
|
|
|
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
|
|
|
|
conv_bias (bool): Use bias layers w/ all convolutions (default: True)
|
|
|
|
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
|
|
|
|
use_grn (bool): Use Global Response Norm (ConvNeXt-V2) in MLP (default: False)
|
|
|
|
conv_bias: Use bias layers w/ all convolutions.
|
|
|
|
act_layer (Union[str, nn.Module]): Activation Layer
|
|
|
|
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
|
|
|
|
norm_layer (Union[str, nn.Module]): Normalization Layer
|
|
|
|
act_layer: Activation layer type.
|
|
|
|
drop_rate (float): Head dropout rate (default: 0.)
|
|
|
|
norm_layer: Normalization layer type.
|
|
|
|
drop_path_rate (float): Stochastic depth rate (default: 0.)
|
|
|
|
drop_rate: Head pre-classifier dropout rate.
|
|
|
|
|
|
|
|
drop_path_rate: Stochastic depth drop rate.
|
|
|
|
"""
|
|
|
|
"""
|
|
|
|
super().__init__()
|
|
|
|
super().__init__()
|
|
|
|
assert output_stride in (8, 16, 32)
|
|
|
|
assert output_stride in (8, 16, 32)
|
|
|
@ -307,14 +311,26 @@ class ConvNeXt(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
|
|
|
|
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
|
|
|
|
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
|
|
|
|
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
|
|
|
|
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
|
|
|
|
if head_norm_first:
|
|
|
|
self.head = nn.Sequential(OrderedDict([
|
|
|
|
assert not head_hidden_size
|
|
|
|
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
|
|
|
|
self.norm_pre = norm_layer(self.num_features)
|
|
|
|
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
|
|
|
|
self.head = ClassifierHead(
|
|
|
|
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
|
|
|
|
self.num_features,
|
|
|
|
('drop', nn.Dropout(self.drop_rate)),
|
|
|
|
num_classes,
|
|
|
|
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
|
|
|
|
pool_type=global_pool,
|
|
|
|
|
|
|
|
drop_rate=self.drop_rate,
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
else:
|
|
|
|
|
|
|
|
self.norm_pre = nn.Identity()
|
|
|
|
|
|
|
|
self.head = NormMlpClassifierHead(
|
|
|
|
|
|
|
|
self.num_features,
|
|
|
|
|
|
|
|
num_classes,
|
|
|
|
|
|
|
|
hidden_size=head_hidden_size,
|
|
|
|
|
|
|
|
pool_type=global_pool,
|
|
|
|
|
|
|
|
drop_rate=self.drop_rate,
|
|
|
|
|
|
|
|
norm_layer=norm_layer,
|
|
|
|
|
|
|
|
act_layer='gelu',
|
|
|
|
|
|
|
|
)
|
|
|
|
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
|
|
|
|
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.ignore
|
|
|
|
@torch.jit.ignore
|
|
|
@ -338,10 +354,7 @@ class ConvNeXt(nn.Module):
|
|
|
|
return self.head.fc
|
|
|
|
return self.head.fc
|
|
|
|
|
|
|
|
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
|
|
|
def reset_classifier(self, num_classes=0, global_pool=None):
|
|
|
|
if global_pool is not None:
|
|
|
|
self.head.reset(num_classes, global_pool=global_pool)
|
|
|
|
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
|
|
|
|
|
|
|
|
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
|
|
|
|
|
|
|
|
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward_features(self, x):
|
|
|
|
def forward_features(self, x):
|
|
|
|
x = self.stem(x)
|
|
|
|
x = self.stem(x)
|
|
|
@ -350,12 +363,7 @@ class ConvNeXt(nn.Module):
|
|
|
|
return x
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
def forward_head(self, x, pre_logits: bool = False):
|
|
|
|
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
|
|
|
|
return self.head(x, pre_logits=pre_logits)
|
|
|
|
x = self.head.global_pool(x)
|
|
|
|
|
|
|
|
x = self.head.norm(x)
|
|
|
|
|
|
|
|
x = self.head.flatten(x)
|
|
|
|
|
|
|
|
x = self.head.drop(x)
|
|
|
|
|
|
|
|
return x if pre_logits else self.head.fc(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
x = self.forward_features(x)
|
|
|
|
x = self.forward_features(x)
|
|
|
|