diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index 5bc4db99..6f2ab83e 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -6,9 +6,14 @@ from .activations_jit import *
 from .activations_me import *
 from .config import is_exportable, is_scriptable, is_no_jit
 
+# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
+# will use native version if present. Eventually, the custom Swish layers will be removed
+# and only native 'silu' will be used.
+_has_silu = 'silu' in dir(torch.nn.functional)
 
 _ACT_FN_DEFAULT = dict(
-    swish=swish,
+    silu=F.silu if _has_silu else swish,
+    swish=F.silu if _has_silu else swish,
     mish=mish,
     relu=F.relu,
     relu6=F.relu6,
@@ -26,7 +31,8 @@ _ACT_FN_DEFAULT = dict(
 )
 
 _ACT_FN_JIT = dict(
-    swish=swish_jit,
+    silu=F.silu if _has_silu else swish_jit,
+    swish=F.silu if _has_silu else swish_jit,
     mish=mish_jit,
     hard_sigmoid=hard_sigmoid_jit,
     hard_swish=hard_swish_jit,
@@ -34,7 +40,8 @@ _ACT_FN_JIT = dict(
 )
 
 _ACT_FN_ME = dict(
-    swish=swish_me,
+    silu=F.silu if _has_silu else swish_me,
+    swish=F.silu if _has_silu else swish_me,
     mish=mish_me,
     hard_sigmoid=hard_sigmoid_me,
     hard_swish=hard_swish_me,
@@ -42,7 +49,8 @@ _ACT_FN_ME = dict(
 )
 
 _ACT_LAYER_DEFAULT = dict(
-    swish=Swish,
+    silu=nn.SiLU if _has_silu else Swish,
+    swish=nn.SiLU if _has_silu else Swish,
     mish=Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
@@ -60,7 +68,8 @@ _ACT_LAYER_DEFAULT = dict(
 )
 
 _ACT_LAYER_JIT = dict(
-    swish=SwishJit,
+    silu=nn.SiLU if _has_silu else SwishJit,
+    swish=nn.SiLU if _has_silu else SwishJit,
     mish=MishJit,
     hard_sigmoid=HardSigmoidJit,
     hard_swish=HardSwishJit,
@@ -68,7 +77,8 @@ _ACT_LAYER_JIT = dict(
 )
 
 _ACT_LAYER_ME = dict(
-    swish=SwishMe,
+    silu=nn.SiLU if _has_silu else SwishMe,
+    swish=nn.SiLU if _has_silu else SwishMe,
     mish=MishMe,
     hard_sigmoid=HardSigmoidMe,
     hard_swish=HardSwishMe,
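
For reference, the change above is a plain runtime feature check: when this PyTorch build exposes a native 'silu' operator (PyTorch 1.7+), both the 'silu' and 'swish' registry keys resolve to it, otherwise they fall back to the existing custom Swish implementations. A minimal standalone sketch of that pattern, not taken from the PR itself; the `swish` body here is a simplified stand-in for timm's pure-Python fallback:

    import torch
    import torch.nn.functional as F

    def swish(x, inplace: bool = False):
        # Pure-Python fallback: swish/silu(x) = x * sigmoid(x)
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())

    # Prefer the fused native op when this PyTorch build (>= 1.7) provides it.
    _has_silu = 'silu' in dir(torch.nn.functional)
    act_fn = F.silu if _has_silu else swish

    x = torch.randn(8)
    assert torch.allclose(act_fn(x), x * torch.sigmoid(x))  # same result up to float tolerance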