From 6064d16a2dfe89b1d3706df338cecfdcee395d1f Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 1 Jul 2022 15:16:41 -0700 Subject: [PATCH] Add initial EdgeNeXt import. Significant cleanup / reorg (like ConvNeXt). Fix #1320 * edgenext refactored for torchscript compat, stage base organization * slight refactor of ConvNeXt to match some EdgeNeXt additions * remove use of funky LayerNorm layer in ConvNeXt and just use nn.LayerNorm and LayerNorm2d (permute) --- timm/models/__init__.py | 1 + timm/models/convnext.py | 190 ++++++++------ timm/models/edgenext.py | 545 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 665 insertions(+), 71 deletions(-) create mode 100644 timm/models/edgenext.py diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 4f81683a..195e451b 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -12,6 +12,7 @@ from .deit import * from .densenet import * from .dla import * from .dpn import * +from .edgenext import * from .efficientnet import * from .ghostnet import * from .gluon_resnet import * diff --git a/timm/models/convnext.py b/timm/models/convnext.py index 1aacef2b..662695c7 100644 --- a/timm/models/convnext.py +++ b/timm/models/convnext.py @@ -19,7 +19,7 @@ import torch.nn.functional as F from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .fx_features import register_notrace_module from .helpers import named_apply, build_model_with_cfg, checkpoint_seq -from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp +from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d from .registry import register_model @@ -44,6 +44,7 @@ default_cfgs = dict( convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_nano_hnf=_cfg(url=''), + convnext_nano_ols=_cfg(url=''), convnext_tiny_hnf=_cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', crop_pct=0.95), @@ -88,35 +89,6 @@ default_cfgs = dict( ) -def _is_contiguous(tensor: torch.Tensor) -> bool: - # jit is oh so lovely :/ - # if torch.jit.is_tracing(): - # return True - if torch.jit.is_scripting(): - return tensor.is_contiguous() - else: - return tensor.is_contiguous(memory_format=torch.contiguous_format) - - -@register_notrace_module -class LayerNorm2d(nn.LayerNorm): - r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W). - """ - - def __init__(self, normalized_shape, eps=1e-6): - super().__init__(normalized_shape, eps=eps) - - def forward(self, x) -> torch.Tensor: - if _is_contiguous(x): - return F.layer_norm( - x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - else: - s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True) - x = (x - u) * torch.rsqrt(s + self.eps) - x = x * self.weight[:, None, None] + self.bias[:, None, None] - return x - - class ConvNeXtBlock(nn.Module): """ ConvNeXt Block There are two equivalent implementations: @@ -133,21 +105,39 @@ class ConvNeXtBlock(nn.Module): ls_init_value (float): Init value for Layer Scale. Default: 1e-6. """ - def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): + def __init__( + self, + dim, + dim_out=None, + stride=1, + mlp_ratio=4, + conv_mlp=False, + conv_bias=True, + ls_init_value=1e-6, + norm_layer=None, + act_layer=nn.GELU, + drop_path=0., + ): super().__init__() + dim_out = dim_out or dim if not norm_layer: norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) mlp_layer = ConvMlp if conv_mlp else Mlp self.use_conv_mlp = conv_mlp - self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv - self.norm = norm_layer(dim) - self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) - self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.shortcut_after_dw = stride > 1 + + self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() def forward(self, x): shortcut = x x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + if self.use_conv_mlp: x = self.norm(x) x = self.mlp(x) @@ -158,32 +148,55 @@ class ConvNeXtBlock(nn.Module): x = x.permute(0, 3, 1, 2) if self.gamma is not None: x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + x = self.drop_path(x) + shortcut + #print('b', x.shape) return x class ConvNeXtStage(nn.Module): def __init__( - self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, - norm_layer=None, cl_norm_layer=None, cross_stage=False): + self, + in_chs, + out_chs, + stride=2, + depth=2, + drop_path_rates=None, + ls_init_value=1.0, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + norm_layer_cl=None + ): super().__init__() self.grad_checkpointing = False - if in_chs != out_chs or stride > 1: + if downsample_block or (in_chs == out_chs and stride == 1): + self.downsample = nn.Identity() + else: self.downsample = nn.Sequential( norm_layer(in_chs), - nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), + nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias), ) - else: - self.downsample = nn.Identity() - - dp_rates = dp_rates or [0.] * depth - self.blocks = nn.Sequential(*[ConvNeXtBlock( - dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer if conv_mlp else cl_norm_layer) - for j in range(depth)] - ) + in_chs = out_chs + + drop_path_rates = drop_path_rates or [0.] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append(ConvNeXtBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer if conv_mlp else norm_layer_cl + )) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) def forward(self, x): x = self.downsample(x) @@ -210,41 +223,57 @@ class ConvNeXt(nn.Module): """ def __init__( - self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, - depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', - head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + output_stride=32, + depths=(3, 3, 9, 3), + dims=(96, 192, 384, 768), + ls_init_value=1e-6, + stem_type='patch', + stem_kernel_size=4, + stem_stride=4, + head_init_scale=1., + head_norm_first=False, + downsample_block=False, + conv_mlp=False, + conv_bias=True, + norm_layer=None, + drop_rate=0., + drop_path_rate=0., ): super().__init__() assert output_stride == 32 if norm_layer is None: norm_layer = partial(LayerNorm2d, eps=1e-6) - cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) + norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) else: assert conv_mlp,\ 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' - cl_norm_layer = norm_layer + norm_layer_cl = norm_layer self.num_classes = num_classes self.drop_rate = drop_rate self.feature_info = [] - # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + assert stem_type in ('patch', 'overlap') if stem_type == 'patch': + assert stem_kernel_size == stem_stride + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 self.stem = nn.Sequential( - nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), + nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias), norm_layer(dims[0]) ) - curr_stride = patch_size - prev_chs = dims[0] else: self.stem = nn.Sequential( - nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), - norm_layer(32), - nn.GELU(), - nn.Conv2d(32, 64, kernel_size=3, padding=1), + nn.Conv2d( + in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, + padding=stem_kernel_size // 2, bias=conv_bias), + norm_layer(dims[0]), ) - curr_stride = 2 - prev_chs = 64 + prev_chs = dims[0] + curr_stride = stem_stride self.stages = nn.Sequential() dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] @@ -256,16 +285,24 @@ class ConvNeXt(nn.Module): curr_stride *= stride out_chs = dims[i] stages.append(ConvNeXtStage( - prev_chs, out_chs, stride=stride, - depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, - norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) - ) + prev_chs, + out_chs, + stride=stride, + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl + )) prev_chs = out_chs # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.stages = nn.Sequential(*stages) - self.num_features = prev_chs + # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() @@ -327,10 +364,11 @@ class ConvNeXt(nn.Module): def _init_weights(module, name=None, head_init_scale=1.0): if isinstance(module, nn.Conv2d): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + if module.bias is not None: + nn.init.zeros_(module.bias) elif isinstance(module, nn.Linear): trunc_normal_(module.weight, std=.02) - nn.init.constant_(module.bias, 0) + nn.init.zeros_(module.bias) if name and 'head.' in name: module.weight.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale) @@ -371,11 +409,21 @@ def _create_convnext(variant, pretrained=False, **kwargs): @register_model def convnext_nano_hnf(pretrained=False, **kwargs): - model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) return model +@register_model +def convnext_nano_ols(pretrained=False, **kwargs): + model_args = dict( + depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True, + conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs) + model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args) + return model + + @register_model def convnext_tiny_hnf(pretrained=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py new file mode 100644 index 00000000..0f8b0464 --- /dev/null +++ b/timm/models/edgenext.py @@ -0,0 +1,545 @@ +""" EdgeNeXt + +Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications` + - https://arxiv.org/abs/2206.10589 + +Original code and weights from https://github.com/mmaaz60/EdgeNeXt + +Modifications and additions for timm by / Copyright 2022, Ross Wightman +""" +import math +import torch +from collections import OrderedDict +from functools import partial +from typing import Tuple + +from torch import nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.models.layers import trunc_normal_tf_ +from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d +from .helpers import named_apply, build_model_with_cfg, checkpoint_seq +from .registry import register_model + + +__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this + + +def _cfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8), + 'crop_pct': 0.9, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.0', 'classifier': 'head.fc', + **kwargs + } + + +default_cfgs = dict( + edgenext_xx_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"), + edgenext_x_small=_cfg( + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"), + # edgenext_small=_cfg( + # url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"), + edgenext_small=_cfg( # USI weights + url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth", + crop_pct=0.95 + ), + + edgenext_small_rw=_cfg(), +) + + +class PositionalEncodingFourier(nn.Module): + def __init__(self, hidden_dim=32, dim=768, temperature=10000): + super().__init__() + self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1) + self.scale = 2 * math.pi + self.temperature = temperature + self.hidden_dim = hidden_dim + self.dim = dim + + def forward(self, shape: Tuple[int, int, int]): + inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool) + y_embed = inv_mask.cumsum(1, dtype=torch.float32) + x_embed = inv_mask.cumsum(2, dtype=torch.float32) + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device) + dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), + pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), + pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + pos = self.token_projection(pos) + + return pos + + +class ConvBlock(nn.Module): + def __init__( + self, + dim, + dim_out=None, + kernel_size=7, + stride=1, + conv_bias=True, + expand_ratio=4, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, drop_path=0., + ): + super().__init__() + dim_out = dim_out or dim + self.shortcut_after_dw = stride > 1 or dim != dim_out + + self.conv_dw = create_conv2d( + dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias) + self.norm = norm_layer(dim_out) + self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x) + if self.shortcut_after_dw: + shortcut = x + + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class CrossCovarianceAttn(nn.Module): + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + attn_drop=0., + proj_drop=0. + ): + super().__init__() + self.num_heads = num_heads + self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + q, k, v = qkv.unbind(0) + + # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + @torch.jit.ignore + def no_weight_decay(self): + return {'temperature'} + + +class SplitTransposeBlock(nn.Module): + def __init__( + self, + dim, + num_scales=1, + num_heads=8, + expand_ratio=4, + use_pos_emb=True, + conv_bias=True, + qkv_bias=True, + ls_init_value=1e-6, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU, + drop_path=0., + attn_drop=0., + proj_drop=0. + ): + super().__init__() + width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales))) + self.width = width + self.num_scales = max(1, num_scales - 1) + + convs = [] + for i in range(self.num_scales): + convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias)) + self.convs = nn.ModuleList(convs) + + self.pos_embd = None + if use_pos_emb: + self.pos_embd = PositionalEncodingFourier(dim=dim) + self.norm_xca = norm_layer(dim) + self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.xca = CrossCovarianceAttn( + dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop) + + self.norm = norm_layer(dim, eps=1e-6) + self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer) + self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + # scales code re-written for torchscript as per my res2net fixes -rw + spx = torch.split(x, self.width, 1) + spo = [] + sp = spx[0] + for i, conv in enumerate(self.convs): + if i > 0: + sp = sp + spx[i] + sp = conv(sp) + spo.append(sp) + spo.append(spx[-1]) + x = torch.cat(spo, 1) + + # XCA + B, C, H, W = x.shape + x = x.reshape(B, C, H * W).permute(0, 2, 1) + if self.pos_embd is not None: + pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1) + x = x + pos_encoding + x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x))) + x = x.reshape(B, H, W, C) + + # Inverted Bottleneck + x = self.norm(x) + x = self.mlp(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = shortcut + self.drop_path(x) + return x + + +class EdgeNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + stride=2, + depth=2, + num_global_blocks=1, + num_heads=4, + scales=2, + kernel_size=7, + expand_ratio=4, + use_pos_emb=False, + downsample_block=False, + conv_bias=True, + ls_init_value=1.0, + drop_path_rates=None, + norm_layer=LayerNorm2d, + norm_layer_cl=partial(nn.LayerNorm, eps=1e-6), + act_layer=nn.GELU + ): + super().__init__() + self.grad_checkpointing = False + + if downsample_block or stride == 1: + self.downsample = nn.Identity() + else: + self.downsample = nn.Sequential( + norm_layer(in_chs), + nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias) + ) + in_chs = out_chs + + stage_blocks = [] + for i in range(depth): + if i < depth - num_global_blocks: + stage_blocks.append( + ConvBlock( + dim=in_chs, + dim_out=out_chs, + stride=stride if downsample_block and i == 0 else 1, + conv_bias=conv_bias, + kernel_size=kernel_size, + expand_ratio=expand_ratio, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + else: + stage_blocks.append( + SplitTransposeBlock( + dim=in_chs, + num_scales=scales, + num_heads=num_heads, + expand_ratio=expand_ratio, + use_pos_emb=use_pos_emb, + conv_bias=conv_bias, + ls_init_value=ls_init_value, + drop_path=drop_path_rates[i], + norm_layer=norm_layer_cl, + act_layer=act_layer, + ) + ) + in_chs = out_chs + self.blocks = nn.Sequential(*stage_blocks) + + def forward(self, x): + x = self.downsample(x) + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint_seq(self.blocks, x) + else: + x = self.blocks(x) + return x + + +class EdgeNeXt(nn.Module): + def __init__( + self, + in_chans=3, + num_classes=1000, + global_pool='avg', + dims=(24, 48, 88, 168), + depths=(3, 3, 9, 3), + global_block_counts=(0, 1, 1, 1), + kernel_sizes=(3, 5, 7, 9), + heads=(8, 8, 8, 8), + d2_scales=(2, 2, 3, 4), + use_pos_emb=(False, True, False, False), + ls_init_value=1e-6, + head_init_scale=1., + expand_ratio=4, + downsample_block=False, + conv_bias=True, + stem_type='patch', + head_norm_first=False, + act_layer=nn.GELU, + drop_path_rate=0., + drop_rate=0., + ): + super().__init__() + self.num_classes = num_classes + self.global_pool = global_pool + self.drop_rate = drop_rate + norm_layer = partial(LayerNorm2d, eps=1e-6) + norm_layer_cl = partial(nn.LayerNorm, eps=1e-6) + + assert stem_type in ('patch', 'overlap') + if stem_type == 'patch': + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias), + norm_layer(dims[0]), + ) + else: + self.stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias), + norm_layer(dims[0]), + ) + + stages = [] + dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + in_chs = dims[0] + for i in range(4): + stages.append(EdgeNeXtStage( + in_chs=in_chs, + out_chs=dims[i], + stride=2 if i > 0 else 1, + depth=depths[i], + num_global_blocks=global_block_counts[i], + num_heads=heads[i], + drop_path_rates=dp_rates[i], + scales=d2_scales[i], + expand_ratio=expand_ratio, + kernel_size=kernel_sizes[i], + use_pos_emb=use_pos_emb[i], + ls_init_value=ls_init_value, + downsample_block=downsample_block, + conv_bias=conv_bias, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + act_layer=act_layer, + )) + in_chs = dims[i] + self.stages = nn.Sequential(*stages) + + self.num_features = dims[-1] + self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() + self.head = nn.Sequential(OrderedDict([ + ('global_pool', SelectAdaptivePool2d(pool_type=global_pool)), + ('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)), + ('flatten', nn.Flatten(1) if global_pool else nn.Identity()), + ('drop', nn.Dropout(self.drop_rate)), + ('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())])) + + named_apply(partial(_init_weights, head_init_scale=head_init_scale), self) + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^stem', + blocks=r'^stages\.(\d+)' if coarse else [ + (r'^stages\.(\d+)\.downsample', (0,)), # blocks + (r'^stages\.(\d+)\.blocks\.(\d+)', None), + (r'^norm_pre', (99999,)) + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes=0, global_pool=None): + if global_pool is not None: + self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool) + self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity() + self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + x = self.stem(x) + x = self.stages(x) + x = self.norm_pre(x) + return x + + def forward_head(self, x, pre_logits: bool = False): + # NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :( + x = self.head.global_pool(x) + x = self.head.norm(x) + x = self.head.flatten(x) + x = self.head.drop(x) + return x if pre_logits else self.head.fc(x) + + def forward(self, x): + x = self.forward_features(x) + x = self.forward_head(x) + return x + + +def _init_weights(module, name=None, head_init_scale=1.0): + if isinstance(module, nn.Conv2d): + trunc_normal_tf_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_tf_(module.weight, std=.02) + nn.init.zeros_(module.bias) + if name and 'head.' in name: + module.weight.data.mul_(head_init_scale) + module.bias.data.mul_(head_init_scale) + + +def checkpoint_filter_fn(state_dict, model): + """ Remap FB checkpoints -> timm """ + if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict: + return state_dict # non-FB checkpoint + + # models were released as train checkpoints... :/ + if 'model_ema' in state_dict: + state_dict = state_dict['model_ema'] + elif 'model' in state_dict: + state_dict = state_dict['model'] + elif 'state_dict' in state_dict: + state_dict = state_dict['state_dict'] + + out_dict = {} + import re + for k, v in state_dict.items(): + k = k.replace('downsample_layers.0.', 'stem.') + k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k) + k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k) + k = k.replace('dwconv', 'conv_dw') + k = k.replace('pwconv', 'mlp.fc') + k = k.replace('head.', 'head.fc.') + if k.startswith('norm.'): + k = k.replace('norm', 'head.norm') + if v.ndim == 2 and 'head' not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + return out_dict + + +def _create_edgenext(variant, pretrained=False, **kwargs): + model = build_model_with_cfg( + EdgeNeXt, variant, pretrained, + pretrained_filter_fn=checkpoint_filter_fn, + feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True), + **kwargs) + return model + + +@register_model +def edgenext_xx_small(pretrained=False, **kwargs): + # 1.33M & 260.58M @ 256 resolution + # 71.23% Top-1 accuracy + # No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=51.66 versus 47.67 for MobileViT_XXS + # For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS + model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_x_small(pretrained=False, **kwargs): + # 2.34M & 538.0M @ 256 resolution + # 75.00% Top-1 accuracy + # No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=31.61 versus 28.49 for MobileViT_XS + # For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs) + return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs) + return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs) + + +@register_model +def edgenext_small_rw(pretrained=False, **kwargs): + # 5.59M & 1260.59M @ 256 resolution + # 79.43% Top-1 accuracy + # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler + # Jetson FPS=20.47 versus 18.86 for MobileViT_S + # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S + model_kwargs = dict( + depths=(3, 3, 9, 3), dims=(48, 96, 192, 384), + downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs) + return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs) +