Add initial EdgeNeXt import. Significant cleanup / reorg (like ConvNeXt). Fix #1320

* edgenext refactored for torchscript compat, stage base organization
* slight refactor of ConvNeXt to match some EdgeNeXt additions
* remove use of funky LayerNorm layer in ConvNeXt and just use nn.LayerNorm and LayerNorm2d (permute)
pull/1327/head
Ross Wightman 3 years ago
parent 7a9c6811c9
commit 6064d16a2d

@ -12,6 +12,7 @@ from .deit import *
from .densenet import * from .densenet import *
from .dla import * from .dla import *
from .dpn import * from .dpn import *
from .edgenext import *
from .efficientnet import * from .efficientnet import *
from .ghostnet import * from .ghostnet import *
from .gluon_resnet import * from .gluon_resnet import *

@ -19,7 +19,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .fx_features import register_notrace_module from .fx_features import register_notrace_module
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .layers import trunc_normal_, ClassifierHead, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp from .layers import trunc_normal_, SelectAdaptivePool2d, DropPath, ConvMlp, Mlp, LayerNorm2d, create_conv2d
from .registry import register_model from .registry import register_model
@ -44,6 +44,7 @@ default_cfgs = dict(
convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"), convnext_large=_cfg(url="https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth"),
convnext_nano_hnf=_cfg(url=''), convnext_nano_hnf=_cfg(url=''),
convnext_nano_ols=_cfg(url=''),
convnext_tiny_hnf=_cfg( convnext_tiny_hnf=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth', url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/convnext_tiny_hnf_a2h-ab7e9df2.pth',
crop_pct=0.95), crop_pct=0.95),
@ -88,35 +89,6 @@ default_cfgs = dict(
) )
def _is_contiguous(tensor: torch.Tensor) -> bool:
# jit is oh so lovely :/
# if torch.jit.is_tracing():
# return True
if torch.jit.is_scripting():
return tensor.is_contiguous()
else:
return tensor.is_contiguous(memory_format=torch.contiguous_format)
@register_notrace_module
class LayerNorm2d(nn.LayerNorm):
r""" LayerNorm for channels_first tensors with 2d spatial dimensions (ie N, C, H, W).
"""
def __init__(self, normalized_shape, eps=1e-6):
super().__init__(normalized_shape, eps=eps)
def forward(self, x) -> torch.Tensor:
if _is_contiguous(x):
return F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
else:
s, u = torch.var_mean(x, dim=1, unbiased=False, keepdim=True)
x = (x - u) * torch.rsqrt(s + self.eps)
x = x * self.weight[:, None, None] + self.bias[:, None, None]
return x
class ConvNeXtBlock(nn.Module): class ConvNeXtBlock(nn.Module):
""" ConvNeXt Block """ ConvNeXt Block
There are two equivalent implementations: There are two equivalent implementations:
@ -133,21 +105,39 @@ class ConvNeXtBlock(nn.Module):
ls_init_value (float): Init value for Layer Scale. Default: 1e-6. ls_init_value (float): Init value for Layer Scale. Default: 1e-6.
""" """
def __init__(self, dim, drop_path=0., ls_init_value=1e-6, conv_mlp=False, mlp_ratio=4, norm_layer=None): def __init__(
self,
dim,
dim_out=None,
stride=1,
mlp_ratio=4,
conv_mlp=False,
conv_bias=True,
ls_init_value=1e-6,
norm_layer=None,
act_layer=nn.GELU,
drop_path=0.,
):
super().__init__() super().__init__()
dim_out = dim_out or dim
if not norm_layer: if not norm_layer:
norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6) norm_layer = partial(LayerNorm2d, eps=1e-6) if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
mlp_layer = ConvMlp if conv_mlp else Mlp mlp_layer = ConvMlp if conv_mlp else Mlp
self.use_conv_mlp = conv_mlp self.use_conv_mlp = conv_mlp
self.conv_dw = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv self.shortcut_after_dw = stride > 1
self.norm = norm_layer(dim)
self.mlp = mlp_layer(dim, int(mlp_ratio * dim), act_layer=nn.GELU) self.conv_dw = create_conv2d(dim, dim_out, kernel_size=7, stride=stride, depthwise=True, bias=conv_bias)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None self.norm = norm_layer(dim_out)
self.mlp = mlp_layer(dim_out, int(mlp_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x): def forward(self, x):
shortcut = x shortcut = x
x = self.conv_dw(x) x = self.conv_dw(x)
if self.shortcut_after_dw:
shortcut = x
if self.use_conv_mlp: if self.use_conv_mlp:
x = self.norm(x) x = self.norm(x)
x = self.mlp(x) x = self.mlp(x)
@ -158,32 +148,55 @@ class ConvNeXtBlock(nn.Module):
x = x.permute(0, 3, 1, 2) x = x.permute(0, 3, 1, 2)
if self.gamma is not None: if self.gamma is not None:
x = x.mul(self.gamma.reshape(1, -1, 1, 1)) x = x.mul(self.gamma.reshape(1, -1, 1, 1))
x = self.drop_path(x) + shortcut x = self.drop_path(x) + shortcut
#print('b', x.shape)
return x return x
class ConvNeXtStage(nn.Module): class ConvNeXtStage(nn.Module):
def __init__( def __init__(
self, in_chs, out_chs, stride=2, depth=2, dp_rates=None, ls_init_value=1.0, conv_mlp=False, self,
norm_layer=None, cl_norm_layer=None, cross_stage=False): in_chs,
out_chs,
stride=2,
depth=2,
drop_path_rates=None,
ls_init_value=1.0,
downsample_block=False,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
norm_layer_cl=None
):
super().__init__() super().__init__()
self.grad_checkpointing = False self.grad_checkpointing = False
if in_chs != out_chs or stride > 1: if downsample_block or (in_chs == out_chs and stride == 1):
self.downsample = nn.Identity()
else:
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
norm_layer(in_chs), norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride), nn.Conv2d(in_chs, out_chs, kernel_size=stride, stride=stride, bias=conv_bias),
) )
else: in_chs = out_chs
self.downsample = nn.Identity()
drop_path_rates = drop_path_rates or [0.] * depth
dp_rates = dp_rates or [0.] * depth stage_blocks = []
self.blocks = nn.Sequential(*[ConvNeXtBlock( for i in range(depth):
dim=out_chs, drop_path=dp_rates[j], ls_init_value=ls_init_value, conv_mlp=conv_mlp, stage_blocks.append(ConvNeXtBlock(
norm_layer=norm_layer if conv_mlp else cl_norm_layer) dim=in_chs,
for j in range(depth)] dim_out=out_chs,
) stride=stride if downsample_block and i == 0 else 1,
drop_path=drop_path_rates[i],
ls_init_value=ls_init_value,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer if conv_mlp else norm_layer_cl
))
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x): def forward(self, x):
x = self.downsample(x) x = self.downsample(x)
@ -210,41 +223,57 @@ class ConvNeXt(nn.Module):
""" """
def __init__( def __init__(
self, in_chans=3, num_classes=1000, global_pool='avg', output_stride=32, patch_size=4, self,
depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), ls_init_value=1e-6, conv_mlp=False, stem_type='patch', in_chans=3,
head_init_scale=1., head_norm_first=False, norm_layer=None, drop_rate=0., drop_path_rate=0., num_classes=1000,
global_pool='avg',
output_stride=32,
depths=(3, 3, 9, 3),
dims=(96, 192, 384, 768),
ls_init_value=1e-6,
stem_type='patch',
stem_kernel_size=4,
stem_stride=4,
head_init_scale=1.,
head_norm_first=False,
downsample_block=False,
conv_mlp=False,
conv_bias=True,
norm_layer=None,
drop_rate=0.,
drop_path_rate=0.,
): ):
super().__init__() super().__init__()
assert output_stride == 32 assert output_stride == 32
if norm_layer is None: if norm_layer is None:
norm_layer = partial(LayerNorm2d, eps=1e-6) norm_layer = partial(LayerNorm2d, eps=1e-6)
cl_norm_layer = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6) norm_layer_cl = norm_layer if conv_mlp else partial(nn.LayerNorm, eps=1e-6)
else: else:
assert conv_mlp,\ assert conv_mlp,\
'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input' 'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
cl_norm_layer = norm_layer norm_layer_cl = norm_layer
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
self.feature_info = [] self.feature_info = []
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 assert stem_type in ('patch', 'overlap')
if stem_type == 'patch': if stem_type == 'patch':
assert stem_kernel_size == stem_stride
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
self.stem = nn.Sequential( self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=patch_size, stride=patch_size), nn.Conv2d(in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride, bias=conv_bias),
norm_layer(dims[0]) norm_layer(dims[0])
) )
curr_stride = patch_size
prev_chs = dims[0]
else: else:
self.stem = nn.Sequential( self.stem = nn.Sequential(
nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1), nn.Conv2d(
norm_layer(32), in_chans, dims[0], kernel_size=stem_kernel_size, stride=stem_stride,
nn.GELU(), padding=stem_kernel_size // 2, bias=conv_bias),
nn.Conv2d(32, 64, kernel_size=3, padding=1), norm_layer(dims[0]),
) )
curr_stride = 2 prev_chs = dims[0]
prev_chs = 64 curr_stride = stem_stride
self.stages = nn.Sequential() self.stages = nn.Sequential()
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
@ -256,16 +285,24 @@ class ConvNeXt(nn.Module):
curr_stride *= stride curr_stride *= stride
out_chs = dims[i] out_chs = dims[i]
stages.append(ConvNeXtStage( stages.append(ConvNeXtStage(
prev_chs, out_chs, stride=stride, prev_chs,
depth=depths[i], dp_rates=dp_rates[i], ls_init_value=ls_init_value, conv_mlp=conv_mlp, out_chs,
norm_layer=norm_layer, cl_norm_layer=cl_norm_layer) stride=stride,
) depth=depths[i],
drop_path_rates=dp_rates[i],
ls_init_value=ls_init_value,
downsample_block=downsample_block,
conv_mlp=conv_mlp,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl
))
prev_chs = out_chs prev_chs = out_chs
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')] self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages) self.stages = nn.Sequential(*stages)
self.num_features = prev_chs self.num_features = prev_chs
# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights) # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity() self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
@ -327,10 +364,11 @@ class ConvNeXt(nn.Module):
def _init_weights(module, name=None, head_init_scale=1.0): def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d): if isinstance(module, nn.Conv2d):
trunc_normal_(module.weight, std=.02) trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0) if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear): elif isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=.02) trunc_normal_(module.weight, std=.02)
nn.init.constant_(module.bias, 0) nn.init.zeros_(module.bias)
if name and 'head.' in name: if name and 'head.' in name:
module.weight.data.mul_(head_init_scale) module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale) module.bias.data.mul_(head_init_scale)
@ -371,11 +409,21 @@ def _create_convnext(variant, pretrained=False, **kwargs):
@register_model @register_model
def convnext_nano_hnf(pretrained=False, **kwargs): def convnext_nano_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs) model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), head_norm_first=True, conv_mlp=True, **kwargs)
model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args) model = _create_convnext('convnext_nano_hnf', pretrained=pretrained, **model_args)
return model return model
@register_model
def convnext_nano_ols(pretrained=False, **kwargs):
model_args = dict(
depths=(2, 2, 8, 2), dims=(80, 160, 320, 640), downsample_block=True,
conv_bias=False, stem_type='overlap', stem_kernel_size=9, **kwargs)
model = _create_convnext('convnext_nano_ols', pretrained=pretrained, **model_args)
return model
@register_model @register_model
def convnext_tiny_hnf(pretrained=False, **kwargs): def convnext_tiny_hnf(pretrained=False, **kwargs):
model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs) model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), head_norm_first=True, conv_mlp=True, **kwargs)

@ -0,0 +1,545 @@
""" EdgeNeXt
Paper: `EdgeNeXt: Efficiently Amalgamated CNN-Transformer Architecture for Mobile Vision Applications`
- https://arxiv.org/abs/2206.10589
Original code and weights from https://github.com/mmaaz60/EdgeNeXt
Modifications and additions for timm by / Copyright 2022, Ross Wightman
"""
import math
import torch
from collections import OrderedDict
from functools import partial
from typing import Tuple
from torch import nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.models.layers import trunc_normal_tf_
from timm.models.layers import DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d
from .helpers import named_apply, build_model_with_cfg, checkpoint_seq
from .registry import register_model
__all__ = ['EdgeNeXt'] # model_registry will add each entrypoint fn to this
def _cfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 256, 256), 'pool_size': (8, 8),
'crop_pct': 0.9, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.0', 'classifier': 'head.fc',
**kwargs
}
default_cfgs = dict(
edgenext_xx_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_xx_small.pth"),
edgenext_x_small=_cfg(
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_x_small.pth"),
# edgenext_small=_cfg(
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.0/edgenext_small.pth"),
edgenext_small=_cfg( # USI weights
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
crop_pct=0.95
),
edgenext_small_rw=_cfg(),
)
class PositionalEncodingFourier(nn.Module):
def __init__(self, hidden_dim=32, dim=768, temperature=10000):
super().__init__()
self.token_projection = nn.Conv2d(hidden_dim * 2, dim, kernel_size=1)
self.scale = 2 * math.pi
self.temperature = temperature
self.hidden_dim = hidden_dim
self.dim = dim
def forward(self, shape: Tuple[int, int, int]):
inv_mask = ~torch.zeros(shape).to(device=self.token_projection.weight.device, dtype=torch.bool)
y_embed = inv_mask.cumsum(1, dtype=torch.float32)
x_embed = inv_mask.cumsum(2, dtype=torch.float32)
eps = 1e-6
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
dim_t = torch.arange(self.hidden_dim, dtype=torch.float32, device=inv_mask.device)
dim_t = self.temperature ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / self.hidden_dim)
pos_x = x_embed[:, :, :, None] / dim_t
pos_y = y_embed[:, :, :, None] / dim_t
pos_x = torch.stack(
(pos_x[:, :, :, 0::2].sin(),
pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos_y = torch.stack(
(pos_y[:, :, :, 0::2].sin(),
pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
pos = self.token_projection(pos)
return pos
class ConvBlock(nn.Module):
def __init__(
self,
dim,
dim_out=None,
kernel_size=7,
stride=1,
conv_bias=True,
expand_ratio=4,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU, drop_path=0.,
):
super().__init__()
dim_out = dim_out or dim
self.shortcut_after_dw = stride > 1 or dim != dim_out
self.conv_dw = create_conv2d(
dim, dim_out, kernel_size=kernel_size, stride=stride, depthwise=True, bias=conv_bias)
self.norm = norm_layer(dim_out)
self.mlp = Mlp(dim_out, int(expand_ratio * dim_out), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim_out)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
if self.shortcut_after_dw:
shortcut = x
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class CrossCovarianceAttn(nn.Module):
def __init__(
self,
dim,
num_heads=8,
qkv_bias=False,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
self.num_heads = num_heads
self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@torch.jit.ignore
def no_weight_decay(self):
return {'temperature'}
class SplitTransposeBlock(nn.Module):
def __init__(
self,
dim,
num_scales=1,
num_heads=8,
expand_ratio=4,
use_pos_emb=True,
conv_bias=True,
qkv_bias=True,
ls_init_value=1e-6,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU,
drop_path=0.,
attn_drop=0.,
proj_drop=0.
):
super().__init__()
width = max(int(math.ceil(dim / num_scales)), int(math.floor(dim // num_scales)))
self.width = width
self.num_scales = max(1, num_scales - 1)
convs = []
for i in range(self.num_scales):
convs.append(create_conv2d(width, width, kernel_size=3, depthwise=True, bias=conv_bias))
self.convs = nn.ModuleList(convs)
self.pos_embd = None
if use_pos_emb:
self.pos_embd = PositionalEncodingFourier(dim=dim)
self.norm_xca = norm_layer(dim)
self.gamma_xca = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.xca = CrossCovarianceAttn(
dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=proj_drop)
self.norm = norm_layer(dim, eps=1e-6)
self.mlp = Mlp(dim, int(expand_ratio * dim), act_layer=act_layer)
self.gamma = nn.Parameter(ls_init_value * torch.ones(dim)) if ls_init_value > 0 else None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
# scales code re-written for torchscript as per my res2net fixes -rw
spx = torch.split(x, self.width, 1)
spo = []
sp = spx[0]
for i, conv in enumerate(self.convs):
if i > 0:
sp = sp + spx[i]
sp = conv(sp)
spo.append(sp)
spo.append(spx[-1])
x = torch.cat(spo, 1)
# XCA
B, C, H, W = x.shape
x = x.reshape(B, C, H * W).permute(0, 2, 1)
if self.pos_embd is not None:
pos_encoding = self.pos_embd((B, H, W)).reshape(B, -1, x.shape[1]).permute(0, 2, 1)
x = x + pos_encoding
x = x + self.drop_path(self.gamma_xca * self.xca(self.norm_xca(x)))
x = x.reshape(B, H, W, C)
# Inverted Bottleneck
x = self.norm(x)
x = self.mlp(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
class EdgeNeXtStage(nn.Module):
def __init__(
self,
in_chs,
out_chs,
stride=2,
depth=2,
num_global_blocks=1,
num_heads=4,
scales=2,
kernel_size=7,
expand_ratio=4,
use_pos_emb=False,
downsample_block=False,
conv_bias=True,
ls_init_value=1.0,
drop_path_rates=None,
norm_layer=LayerNorm2d,
norm_layer_cl=partial(nn.LayerNorm, eps=1e-6),
act_layer=nn.GELU
):
super().__init__()
self.grad_checkpointing = False
if downsample_block or stride == 1:
self.downsample = nn.Identity()
else:
self.downsample = nn.Sequential(
norm_layer(in_chs),
nn.Conv2d(in_chs, out_chs, kernel_size=2, stride=2, bias=conv_bias)
)
in_chs = out_chs
stage_blocks = []
for i in range(depth):
if i < depth - num_global_blocks:
stage_blocks.append(
ConvBlock(
dim=in_chs,
dim_out=out_chs,
stride=stride if downsample_block and i == 0 else 1,
conv_bias=conv_bias,
kernel_size=kernel_size,
expand_ratio=expand_ratio,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
else:
stage_blocks.append(
SplitTransposeBlock(
dim=in_chs,
num_scales=scales,
num_heads=num_heads,
expand_ratio=expand_ratio,
use_pos_emb=use_pos_emb,
conv_bias=conv_bias,
ls_init_value=ls_init_value,
drop_path=drop_path_rates[i],
norm_layer=norm_layer_cl,
act_layer=act_layer,
)
)
in_chs = out_chs
self.blocks = nn.Sequential(*stage_blocks)
def forward(self, x):
x = self.downsample(x)
if self.grad_checkpointing and not torch.jit.is_scripting():
x = checkpoint_seq(self.blocks, x)
else:
x = self.blocks(x)
return x
class EdgeNeXt(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=1000,
global_pool='avg',
dims=(24, 48, 88, 168),
depths=(3, 3, 9, 3),
global_block_counts=(0, 1, 1, 1),
kernel_sizes=(3, 5, 7, 9),
heads=(8, 8, 8, 8),
d2_scales=(2, 2, 3, 4),
use_pos_emb=(False, True, False, False),
ls_init_value=1e-6,
head_init_scale=1.,
expand_ratio=4,
downsample_block=False,
conv_bias=True,
stem_type='patch',
head_norm_first=False,
act_layer=nn.GELU,
drop_path_rate=0.,
drop_rate=0.,
):
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.drop_rate = drop_rate
norm_layer = partial(LayerNorm2d, eps=1e-6)
norm_layer_cl = partial(nn.LayerNorm, eps=1e-6)
assert stem_type in ('patch', 'overlap')
if stem_type == 'patch':
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4, bias=conv_bias),
norm_layer(dims[0]),
)
else:
self.stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=9, stride=4, padding=9 // 2, bias=conv_bias),
norm_layer(dims[0]),
)
stages = []
dp_rates = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)]
in_chs = dims[0]
for i in range(4):
stages.append(EdgeNeXtStage(
in_chs=in_chs,
out_chs=dims[i],
stride=2 if i > 0 else 1,
depth=depths[i],
num_global_blocks=global_block_counts[i],
num_heads=heads[i],
drop_path_rates=dp_rates[i],
scales=d2_scales[i],
expand_ratio=expand_ratio,
kernel_size=kernel_sizes[i],
use_pos_emb=use_pos_emb[i],
ls_init_value=ls_init_value,
downsample_block=downsample_block,
conv_bias=conv_bias,
norm_layer=norm_layer,
norm_layer_cl=norm_layer_cl,
act_layer=act_layer,
))
in_chs = dims[i]
self.stages = nn.Sequential(*stages)
self.num_features = dims[-1]
self.norm_pre = norm_layer(self.num_features) if head_norm_first else nn.Identity()
self.head = nn.Sequential(OrderedDict([
('global_pool', SelectAdaptivePool2d(pool_type=global_pool)),
('norm', nn.Identity() if head_norm_first else norm_layer(self.num_features)),
('flatten', nn.Flatten(1) if global_pool else nn.Identity()),
('drop', nn.Dropout(self.drop_rate)),
('fc', nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity())]))
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)
@torch.jit.ignore
def group_matcher(self, coarse=False):
return dict(
stem=r'^stem',
blocks=r'^stages\.(\d+)' if coarse else [
(r'^stages\.(\d+)\.downsample', (0,)), # blocks
(r'^stages\.(\d+)\.blocks\.(\d+)', None),
(r'^norm_pre', (99999,))
]
)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
for s in self.stages:
s.grad_checkpointing = enable
@torch.jit.ignore
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes=0, global_pool=None):
if global_pool is not None:
self.head.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
self.head.flatten = nn.Flatten(1) if global_pool else nn.Identity()
self.head.fc = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.norm_pre(x)
return x
def forward_head(self, x, pre_logits: bool = False):
# NOTE nn.Sequential in head broken down since can't call head[:-1](x) in torchscript :(
x = self.head.global_pool(x)
x = self.head.norm(x)
x = self.head.flatten(x)
x = self.head.drop(x)
return x if pre_logits else self.head.fc(x)
def forward(self, x):
x = self.forward_features(x)
x = self.forward_head(x)
return x
def _init_weights(module, name=None, head_init_scale=1.0):
if isinstance(module, nn.Conv2d):
trunc_normal_tf_(module.weight, std=.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
trunc_normal_tf_(module.weight, std=.02)
nn.init.zeros_(module.bias)
if name and 'head.' in name:
module.weight.data.mul_(head_init_scale)
module.bias.data.mul_(head_init_scale)
def checkpoint_filter_fn(state_dict, model):
""" Remap FB checkpoints -> timm """
if 'head.norm.weight' in state_dict or 'norm_pre.weight' in state_dict:
return state_dict # non-FB checkpoint
# models were released as train checkpoints... :/
if 'model_ema' in state_dict:
state_dict = state_dict['model_ema']
elif 'model' in state_dict:
state_dict = state_dict['model']
elif 'state_dict' in state_dict:
state_dict = state_dict['state_dict']
out_dict = {}
import re
for k, v in state_dict.items():
k = k.replace('downsample_layers.0.', 'stem.')
k = re.sub(r'stages.([0-9]+).([0-9]+)', r'stages.\1.blocks.\2', k)
k = re.sub(r'downsample_layers.([0-9]+).([0-9]+)', r'stages.\1.downsample.\2', k)
k = k.replace('dwconv', 'conv_dw')
k = k.replace('pwconv', 'mlp.fc')
k = k.replace('head.', 'head.fc.')
if k.startswith('norm.'):
k = k.replace('norm', 'head.norm')
if v.ndim == 2 and 'head' not in k:
model_shape = model.state_dict()[k].shape
v = v.reshape(model_shape)
out_dict[k] = v
return out_dict
def _create_edgenext(variant, pretrained=False, **kwargs):
model = build_model_with_cfg(
EdgeNeXt, variant, pretrained,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=(0, 1, 2, 3), flatten_sequential=True),
**kwargs)
return model
@register_model
def edgenext_xx_small(pretrained=False, **kwargs):
# 1.33M & 260.58M @ 256 resolution
# 71.23% Top-1 accuracy
# No AA, Color Jitter=0.4, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=51.66 versus 47.67 for MobileViT_XXS
# For A100: FPS @ BS=1: 212.13 & @ BS=256: 7042.06 versus FPS @ BS=1: 96.68 & @ BS=256: 4624.71 for MobileViT_XXS
model_kwargs = dict(depths=(2, 2, 6, 2), dims=(24, 48, 88, 168), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_xx_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_x_small(pretrained=False, **kwargs):
# 2.34M & 538.0M @ 256 resolution
# 75.00% Top-1 accuracy
# No AA, No Mixup & Cutmix, DropPath=0.0, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=31.61 versus 28.49 for MobileViT_XS
# For A100: FPS @ BS=1: 179.55 & @ BS=256: 4404.95 versus FPS @ BS=1: 94.55 & @ BS=256: 2361.53 for MobileViT_XS
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(32, 64, 100, 192), heads=(4, 4, 4, 4), **kwargs)
return _create_edgenext('edgenext_x_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(depths=(3, 3, 9, 3), dims=(48, 96, 160, 304), **kwargs)
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
@register_model
def edgenext_small_rw(pretrained=False, **kwargs):
# 5.59M & 1260.59M @ 256 resolution
# 79.43% Top-1 accuracy
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
model_kwargs = dict(
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
return _create_edgenext('edgenext_small_rw', pretrained=pretrained, **model_kwargs)
Loading…
Cancel
Save