|
|
@ -50,6 +50,12 @@ default_cfgs = dict(
|
|
|
|
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
|
|
|
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
|
|
|
|
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
|
|
|
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
|
|
|
),
|
|
|
|
),
|
|
|
|
|
|
|
|
# edgenext_base=_cfg(
|
|
|
|
|
|
|
|
# url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
|
|
|
|
|
|
|
|
edgenext_base=_cfg( # USI weights
|
|
|
|
|
|
|
|
url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
|
|
|
|
|
|
|
|
crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
|
|
|
|
edgenext_small_rw=_cfg(
|
|
|
|
edgenext_small_rw=_cfg(
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
|
|
|
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
|
|
|
@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
def forward(self, x):
|
|
|
|
B, N, C = x.shape
|
|
|
|
B, N, C = x.shape
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
|
|
|
|
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
|
|
|
|
q, k, v = qkv.unbind(0)
|
|
|
|
q, k, v = qkv.unbind(0)
|
|
|
|
|
|
|
|
|
|
|
|
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
|
|
|
|
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
|
|
|
@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module):
|
|
|
|
shortcut = x
|
|
|
|
shortcut = x
|
|
|
|
|
|
|
|
|
|
|
|
# scales code re-written for torchscript as per my res2net fixes -rw
|
|
|
|
# scales code re-written for torchscript as per my res2net fixes -rw
|
|
|
|
spx = torch.split(x, self.width, 1)
|
|
|
|
# NOTE torch.split(x, self.width, 1) causing issues with ONNX export
|
|
|
|
|
|
|
|
spx = x.chunk(len(self.convs) + 1, dim=1)
|
|
|
|
spo = []
|
|
|
|
spo = []
|
|
|
|
sp = spx[0]
|
|
|
|
sp = spx[0]
|
|
|
|
for i, conv in enumerate(self.convs):
|
|
|
|
for i, conv in enumerate(self.convs):
|
|
|
@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
|
|
|
|
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
|
|
|
|
return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
|
|
|
|
def edgenext_base(pretrained=False, **kwargs):
|
|
|
|
|
|
|
|
# 18.51M & 3840.93M @ 256 resolution
|
|
|
|
|
|
|
|
# 82.5% (normal) 83.7% (USI) Top-1 accuracy
|
|
|
|
|
|
|
|
# AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
|
|
|
|
|
|
|
# Jetson FPS=xx.xx versus xx.xx for MobileViT_S
|
|
|
|
|
|
|
|
# For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
|
|
|
|
|
|
|
|
model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
|
|
|
|
|
|
|
|
return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_model
|
|
|
|
@register_model
|
|
|
|
def edgenext_small_rw(pretrained=False, **kwargs):
|
|
|
|
def edgenext_small_rw(pretrained=False, **kwargs):
|
|
|
|
# 5.59M & 1260.59M @ 256 resolution
|
|
|
|
|
|
|
|
# 79.43% Top-1 accuracy
|
|
|
|
|
|
|
|
# AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
|
|
|
|
|
|
|
|
# Jetson FPS=20.47 versus 18.86 for MobileViT_S
|
|
|
|
|
|
|
|
# For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
|
|
|
|
|
|
|
|
model_kwargs = dict(
|
|
|
|
model_kwargs = dict(
|
|
|
|
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
|
|
|
|
depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
|
|
|
|
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
|
|
|
|
downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
|
|
|
|