Add edgenext_base model def & weight link, update to improve ONNX export #1385

pull/1420/head
Ross Wightman 2 years ago
parent 56596e4e84
commit 13565aad50

@@ -50,6 +50,12 @@ default_cfgs = dict(
         url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.1/edgenext_small_usi.pth",
         crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
     ),
-    # edgenext_base=_cfg(
-    #     url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth"),
+    edgenext_base=_cfg(  # USI weights
+        url="https://github.com/mmaaz60/EdgeNeXt/releases/download/v1.2/edgenext_base_usi.pth",
+        crop_pct=0.95, test_input_size=(3, 320, 320), test_crop_pct=1.0,
+    ),
     edgenext_small_rw=_cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/edgenext_small_rw-sw-b00041bb.pth',
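The new `edgenext_base` entry keeps the 256/0.95 train-time settings but is meant to be evaluated at 320 with a full crop. A minimal sketch of how those cfg fields are typically consumed, assuming timm's `resolve_data_config` helper and its `use_test_size` flag (illustration only, not part of this diff):

```python
import timm
from timm.data import resolve_data_config

model = timm.create_model('edgenext_base', pretrained=False)

# use_test_size=True should pick up test_input_size / test_crop_pct from the
# default cfg, i.e. (3, 320, 320) and 1.0 here, instead of the train-time values.
eval_cfg = resolve_data_config({}, model=model, use_test_size=True)
print(eval_cfg['input_size'], eval_cfg['crop_pct'])
```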
@@ -154,7 +160,7 @@ class CrossCovarianceAttn(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
         q, k, v = qkv.unbind(0)
         # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
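Letting `reshape` infer the last dimension with `-1` gives the same head dim as `C // self.num_heads` while avoiding an explicit integer division on a traced shape, which tends to export more cleanly. A standalone sketch of the shapes involved (example sizes are arbitrary, not from the repo):

```python
import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 196, 64, 8
qkv_out = torch.randn(B, N, 3 * C)   # stand-in for self.qkv(x)

# -1 infers head_dim = C // num_heads without an explicit division
qkv = qkv_out.reshape(B, N, 3, num_heads, -1).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)              # each: (B, num_heads, head_dim, N)
assert q.shape == (B, num_heads, C // num_heads, N)

# cross-covariance attention: the attn map is head_dim x head_dim, not N x N
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
assert attn.shape == (B, num_heads, C // num_heads, C // num_heads)
```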
@@ -217,7 +223,8 @@ class SplitTransposeBlock(nn.Module):
         shortcut = x
 
         # scales code re-written for torchscript as per my res2net fixes -rw
-        spx = torch.split(x, self.width, 1)
+        # NOTE torch.split(x, self.width, 1) causing issues with ONNX export
+        spx = x.chunk(len(self.convs) + 1, dim=1)
         spo = []
         sp = spx[0]
         for i, conv in enumerate(self.convs):
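The commit swaps `torch.split(x, self.width, 1)` for `x.chunk(len(self.convs) + 1, dim=1)` because the integer-split form was causing ONNX export issues. A quick standalone check (illustrative names and sizes) that the two calls yield identical pieces when the channel count divides evenly, so the swap is behavior-preserving for these configs:

```python
import torch

n_convs, width = 3, 72                      # e.g. a 288-channel stage split 4 ways
x = torch.randn(2, width * (n_convs + 1), 8, 8)

spx_split = torch.split(x, width, 1)        # old form: integer split_size
spx_chunk = x.chunk(n_convs + 1, dim=1)     # new form: fixed number of chunks

assert len(spx_split) == len(spx_chunk) == n_convs + 1
for a, b in zip(spx_split, spx_chunk):
    assert torch.equal(a, b)
```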
@@ -545,13 +552,19 @@ def edgenext_small(pretrained=False, **kwargs):
     return _create_edgenext('edgenext_small', pretrained=pretrained, **model_kwargs)
 
 
+@register_model
+def edgenext_base(pretrained=False, **kwargs):
+    # 18.51M & 3840.93M @ 256 resolution
+    # 82.5% (normal) 83.7% (USI) Top-1 accuracy
+    # AA=True, Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
+    # Jetson FPS=xx.xx versus xx.xx for MobileViT_S
+    # For A100: FPS @ BS=1: xxx.xx & @ BS=256: xxxx.xx
+    model_kwargs = dict(depths=[3, 3, 9, 3], dims=[80, 160, 288, 584], **kwargs)
+    return _create_edgenext('edgenext_base', pretrained=pretrained, **model_kwargs)
+
+
 @register_model
 def edgenext_small_rw(pretrained=False, **kwargs):
-    # 5.59M & 1260.59M @ 256 resolution
-    # 79.43% Top-1 accuracy
-    # AA=True, No Mixup & Cutmix, DropPath=0.1, BS=4096, lr=0.006, multi-scale-sampler
-    # Jetson FPS=20.47 versus 18.86 for MobileViT_S
-    # For A100: FPS @ BS=1: 172.33 & @ BS=256: 3010.25 versus FPS @ BS=1: 93.84 & @ BS=256: 1785.92 for MobileViT_S
     model_kwargs = dict(
         depths=(3, 3, 9, 3), dims=(48, 96, 192, 384),
         downsample_block=True, conv_bias=False, stem_type='overlap', **kwargs)
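For completeness, a short usage sketch tying the two halves of this commit together: creating the new `edgenext_base` and running the ONNX export the other changes are aimed at unblocking. The output path and opset version are arbitrary choices, not anything prescribed by the repo:

```python
import timm
import torch

# model def added in this commit; pretrained=True would pull the USI weights
model = timm.create_model('edgenext_base', pretrained=False).eval()

# the chunk/reshape changes above are intended to make this trace cleanly
dummy = torch.randn(1, 3, 256, 256)
torch.onnx.export(
    model, dummy, 'edgenext_base.onnx',
    input_names=['input'], output_names=['logits'],
    opset_version=13,
)
```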
