From fd962c4b4a5214650a8678a2a987d1853933e1c0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sun, 29 Nov 2020 21:56:55 -0800 Subject: [PATCH] Native SiLU (Swish) op doesn't export to ONNX --- timm/models/layers/create_act.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index 6f2ab83e..3f39bcf4 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -98,7 +98,10 @@ def get_act_fn(name='relu'): # custom autograd, then fallback if name in _ACT_FN_ME: return _ACT_FN_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + if not (is_no_jit() or is_exportable()): if name in _ACT_FN_JIT: return _ACT_FN_JIT[name] return _ACT_FN_DEFAULT[name] @@ -114,7 +117,10 @@ def get_act_layer(name='relu'): if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] - if not is_no_jit(): + if is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + if not (is_no_jit() or is_exportable()): if name in _ACT_LAYER_JIT: return _ACT_LAYER_JIT[name] return _ACT_LAYER_DEFAULT[name]