Native SiLU (Swish) op doesn't export to ONNX

pull/297/head
Ross Wightman 4 years ago
parent 27bbc70d71
commit fd962c4b4a

@@ -98,7 +98,10 @@ def get_act_fn(name='relu'):
         # custom autograd, then fallback
         if name in _ACT_FN_ME:
             return _ACT_FN_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_FN_JIT:
             return _ACT_FN_JIT[name]
     return _ACT_FN_DEFAULT[name]
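
For reference, the `swish` fallback returned above is built from primitive ops, so ONNX tracing emits plain `Mul` and `Sigmoid` nodes rather than a single native SiLU op it could not map at the time. A minimal sketch of such a function (the `inplace` convention here mirrors timm's other activations and is an assumption, not the exact source):

```python
import torch


def swish(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    """Swish/SiLU built from primitive ops: x * sigmoid(x).

    Decomposing into mul + sigmoid lets the ONNX exporter emit
    Mul and Sigmoid nodes instead of an unmapped aten op.
    """
    return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
```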
@@ -114,7 +117,10 @@ def get_act_layer(name='relu'):
     if not (is_no_jit() or is_exportable() or is_scriptable()):
         if name in _ACT_LAYER_ME:
             return _ACT_LAYER_ME[name]
-    if not is_no_jit():
+    if is_exportable() and name in ('silu', 'swish'):
+        # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
+        return Swish
+    if not (is_no_jit() or is_exportable()):
         if name in _ACT_LAYER_JIT:
             return _ACT_LAYER_JIT[name]
     return _ACT_LAYER_DEFAULT[name]
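
The layer factory gets the same guard: in export mode, `get_act_layer('silu')` returns the decomposed `Swish` module instead of the native layer. A self-contained sketch of the export path this enables (the `Swish` module below is a hypothetical stand-in for timm's class; at the time of this commit the native `nn.SiLU` had no ONNX symbolic, while this decomposition traced cleanly):

```python
import torch
import torch.nn as nn


class Swish(nn.Module):
    """Module form of the fallback: x * sigmoid(x) via primitive ops."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.sigmoid(x)


# Traces to Mul + Sigmoid ONNX nodes, sidestepping the missing
# symbolic for the native SiLU op in PyTorch of that era.
model = nn.Sequential(nn.Linear(8, 8), Swish())
torch.onnx.export(model, torch.randn(1, 8), "swish_fallback.onnx")
```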
