Support the native silu activation (aka swish). An optimized version is available as of PyTorch 1.7.

pull/263/head
Ross Wightman 4 years ago
parent da6cd2cc1f
commit e90edce438

@ -6,9 +6,14 @@ from .activations_jit import *
from .activations_me import * from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit from .config import is_exportable, is_scriptable, is_no_jit
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
# will use the native version if present. Eventually, the custom Swish layers will be removed
# and only the native 'silu' will be used.
_has_silu = 'silu' in dir(torch.nn.functional)
_ACT_FN_DEFAULT = dict( _ACT_FN_DEFAULT = dict(
swish=swish, silu=F.silu if _has_silu else swish,
swish=F.silu if _has_silu else swish,
mish=mish, mish=mish,
relu=F.relu, relu=F.relu,
relu6=F.relu6, relu6=F.relu6,
@ -26,7 +31,8 @@ _ACT_FN_DEFAULT = dict(
) )
_ACT_FN_JIT = dict( _ACT_FN_JIT = dict(
swish=swish_jit, silu=F.silu if _has_silu else swish_jit,
swish=F.silu if _has_silu else swish_jit,
mish=mish_jit, mish=mish_jit,
hard_sigmoid=hard_sigmoid_jit, hard_sigmoid=hard_sigmoid_jit,
hard_swish=hard_swish_jit, hard_swish=hard_swish_jit,
@ -34,7 +40,8 @@ _ACT_FN_JIT = dict(
) )
_ACT_FN_ME = dict( _ACT_FN_ME = dict(
swish=swish_me, silu=F.silu if _has_silu else swish_me,
swish=F.silu if _has_silu else swish_me,
mish=mish_me, mish=mish_me,
hard_sigmoid=hard_sigmoid_me, hard_sigmoid=hard_sigmoid_me,
hard_swish=hard_swish_me, hard_swish=hard_swish_me,
@ -42,7 +49,8 @@ _ACT_FN_ME = dict(
) )
_ACT_LAYER_DEFAULT = dict( _ACT_LAYER_DEFAULT = dict(
swish=Swish, silu=nn.SiLU if _has_silu else Swish,
swish=nn.SiLU if _has_silu else Swish,
mish=Mish, mish=Mish,
relu=nn.ReLU, relu=nn.ReLU,
relu6=nn.ReLU6, relu6=nn.ReLU6,
@ -60,7 +68,8 @@ _ACT_LAYER_DEFAULT = dict(
) )
_ACT_LAYER_JIT = dict( _ACT_LAYER_JIT = dict(
swish=SwishJit, silu=nn.SiLU if _has_silu else SwishJit,
swish=nn.SiLU if _has_silu else SwishJit,
mish=MishJit, mish=MishJit,
hard_sigmoid=HardSigmoidJit, hard_sigmoid=HardSigmoidJit,
hard_swish=HardSwishJit, hard_swish=HardSwishJit,
@ -68,7 +77,8 @@ _ACT_LAYER_JIT = dict(
) )
_ACT_LAYER_ME = dict( _ACT_LAYER_ME = dict(
swish=SwishMe, silu=nn.SiLU if _has_silu else SwishMe,
swish=nn.SiLU if _has_silu else SwishMe,
mish=MishMe, mish=MishMe,
hard_sigmoid=HardSigmoidMe, hard_sigmoid=HardSigmoidMe,
hard_swish=HardSwishMe, hard_swish=HardSwishMe,

Loading…
Cancel
Save