|
|
@ -98,7 +98,10 @@ def get_act_fn(name='relu'):
|
|
|
|
# custom autograd, then fallback
|
|
|
|
# custom autograd, then fallback
|
|
|
|
if name in _ACT_FN_ME:
|
|
|
|
if name in _ACT_FN_ME:
|
|
|
|
return _ACT_FN_ME[name]
|
|
|
|
return _ACT_FN_ME[name]
|
|
|
|
if not is_no_jit():
|
|
|
|
if is_exportable() and name in ('silu', 'swish'):
|
|
|
|
|
|
|
|
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
|
|
|
|
|
|
|
return swish
|
|
|
|
|
|
|
|
if not (is_no_jit() or is_exportable()):
|
|
|
|
if name in _ACT_FN_JIT:
|
|
|
|
if name in _ACT_FN_JIT:
|
|
|
|
return _ACT_FN_JIT[name]
|
|
|
|
return _ACT_FN_JIT[name]
|
|
|
|
return _ACT_FN_DEFAULT[name]
|
|
|
|
return _ACT_FN_DEFAULT[name]
|
|
|
@ -114,7 +117,10 @@ def get_act_layer(name='relu'):
|
|
|
|
if not (is_no_jit() or is_exportable() or is_scriptable()):
|
|
|
|
if not (is_no_jit() or is_exportable() or is_scriptable()):
|
|
|
|
if name in _ACT_LAYER_ME:
|
|
|
|
if name in _ACT_LAYER_ME:
|
|
|
|
return _ACT_LAYER_ME[name]
|
|
|
|
return _ACT_LAYER_ME[name]
|
|
|
|
if not is_no_jit():
|
|
|
|
if is_exportable() and name in ('silu', 'swish'):
|
|
|
|
|
|
|
|
# FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack
|
|
|
|
|
|
|
|
return Swish
|
|
|
|
|
|
|
|
if not (is_no_jit() or is_exportable()):
|
|
|
|
if name in _ACT_LAYER_JIT:
|
|
|
|
if name in _ACT_LAYER_JIT:
|
|
|
|
return _ACT_LAYER_JIT[name]
|
|
|
|
return _ACT_LAYER_JIT[name]
|
|
|
|
return _ACT_LAYER_DEFAULT[name]
|
|
|
|
return _ACT_LAYER_DEFAULT[name]
|
|
|
|