diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index 6f2ab83e..3f39bcf4 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -98,7 +98,10 @@ def get_act_fn(name='relu'):
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_FN_JIT:
             return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
@@ -114,7 +117,10 @@ def get_act_layer(name='relu'):
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return Swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_LAYER_JIT:
             return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]