You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
pytorch-image-models/timm/models/layers/create_act.py

118 lines
2.8 KiB

""" Activation Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .activations import *
from .activations_jit import *
from .activations_me import *
from .config import is_exportable, is_scriptable, is_no_jit
# ---------------------------------------------------------------------------
# Name -> activation *function* tables.
# get_act_fn() looks a name up in _ACT_FN_ME first (only when not exporting,
# scripting, or running with JIT disabled), then in _ACT_FN_JIT (when JIT is
# not disabled), and finally falls back to _ACT_FN_DEFAULT.
# ---------------------------------------------------------------------------

# Plain implementations: torch.nn.functional ops or this package's activations.
_ACT_FN_DEFAULT = dict(
    swish=swish,
    mish=mish,
    relu=F.relu,
    relu6=F.relu6,
    leaky_relu=F.leaky_relu,
    elu=F.elu,
    prelu=F.prelu,
    celu=F.celu,
    selu=F.selu,
    gelu=F.gelu,
    sigmoid=sigmoid,
    tanh=tanh,
    hard_sigmoid=hard_sigmoid,
    hard_swish=hard_swish,
    hard_mish=hard_mish,
)

# JIT variants (from .activations_jit) for the subset of activations that have one.
_ACT_FN_JIT = dict(
    swish=swish_jit,
    mish=mish_jit,
    hard_sigmoid=hard_sigmoid_jit,
    hard_swish=hard_swish_jit,
    hard_mish=hard_mish_jit,
)

# Memory-efficient variants with custom autograd (from .activations_me).
_ACT_FN_ME = dict(
    swish=swish_me,
    mish=mish_me,
    hard_sigmoid=hard_sigmoid_me,
    hard_swish=hard_swish_me,
    hard_mish=hard_mish_me,
)

# ---------------------------------------------------------------------------
# Name -> activation *layer* (nn.Module class) tables, consulted by
# get_act_layer() with the same ME -> JIT -> DEFAULT precedence as above.
# ---------------------------------------------------------------------------

# Plain layer classes. Keys are kept in sync with _ACT_FN_DEFAULT so that every
# name usable with get_act_fn() is also usable with get_act_layer().
_ACT_LAYER_DEFAULT = dict(
    swish=Swish,
    mish=Mish,
    relu=nn.ReLU,
    relu6=nn.ReLU6,
    # Previously missing here even though _ACT_FN_DEFAULT maps 'leaky_relu',
    # so get_act_layer('leaky_relu') raised KeyError.
    leaky_relu=nn.LeakyReLU,
    elu=nn.ELU,
    prelu=nn.PReLU,
    celu=nn.CELU,
    selu=nn.SELU,
    gelu=nn.GELU,
    sigmoid=Sigmoid,
    tanh=Tanh,
    hard_sigmoid=HardSigmoid,
    hard_swish=HardSwish,
    hard_mish=HardMish,
)

# JIT layer variants (from .activations_jit).
_ACT_LAYER_JIT = dict(
    swish=SwishJit,
    mish=MishJit,
    hard_sigmoid=HardSigmoidJit,
    hard_swish=HardSwishJit,
    hard_mish=HardMishJit,
)

# Memory-efficient layer variants with custom autograd (from .activations_me).
_ACT_LAYER_ME = dict(
    swish=SwishMe,
    mish=MishMe,
    hard_sigmoid=HardSigmoidMe,
    hard_swish=HardSwishMe,
    hard_mish=HardMishMe,
)
def get_act_fn(name='relu'):
    """ Activation Function Factory
    Resolve an activation function by name, picking the variant that suits the
    current export / scripting / JIT configuration.

    Lookup precedence:
      1. memory-efficient (_ACT_FN_ME) -- only when no-jit, exportable and
         scriptable modes are all off;
      2. JIT (_ACT_FN_JIT) -- whenever no-jit mode is off;
      3. default (_ACT_FN_DEFAULT).

    Args:
        name: activation name. A falsy value (None, '') returns None.

    Returns:
        The activation callable, or None for a falsy name.

    Raises:
        KeyError: if `name` is not found in any table.
    """
    if not name:
        return None

    # Custom-autograd variants are only safe when we are not exporting/scripting.
    me_allowed = not (is_no_jit() or is_exportable() or is_scriptable())
    if me_allowed and name in _ACT_FN_ME:
        return _ACT_FN_ME[name]

    if not is_no_jit() and name in _ACT_FN_JIT:
        return _ACT_FN_JIT[name]

    return _ACT_FN_DEFAULT[name]
def get_act_layer(name='relu'):
    """ Activation Layer Factory
    Resolve an activation layer class by name, picking the variant that suits the
    current export / scripting / JIT configuration.

    Lookup precedence:
      1. memory-efficient (_ACT_LAYER_ME) -- only when no-jit, exportable and
         scriptable modes are all off;
      2. JIT (_ACT_LAYER_JIT) -- whenever no-jit mode is off;
      3. default (_ACT_LAYER_DEFAULT).

    Args:
        name: activation name. A falsy value (None, '') returns None.

    Returns:
        The activation layer class (not an instance), or None for a falsy name.

    Raises:
        KeyError: if `name` is not found in any table.
    """
    if not name:
        return None

    # Custom-autograd layers are only used when we are not exporting/scripting.
    me_allowed = not (is_no_jit() or is_exportable() or is_scriptable())
    if me_allowed and name in _ACT_LAYER_ME:
        return _ACT_LAYER_ME[name]

    if not is_no_jit() and name in _ACT_LAYER_JIT:
        return _ACT_LAYER_JIT[name]

    return _ACT_LAYER_DEFAULT[name]
def create_act_layer(name, inplace=False, **kwargs):
    """ Activation Layer Instance Factory
    Look up an activation layer class by name via get_act_layer() and instantiate it.

    Args:
        name: activation name. A falsy value (None, '') returns None.
        inplace: requested in-place mode. Forwarded to layers whose constructor
            accepts an `inplace` argument; ignored for layers that do not
            (e.g. nn.PReLU, nn.GELU).
        **kwargs: additional constructor arguments forwarded to the layer.

    Returns:
        An instantiated activation module, or None for a falsy name.

    Raises:
        KeyError: if `name` is not a known activation.
        TypeError: if `kwargs` contains arguments the layer does not accept.
    """
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
    try:
        return act_layer(inplace=inplace, **kwargs)
    except TypeError:
        # Some registered layers (nn.PReLU, nn.GELU) take no `inplace` argument, so
        # retry without it. If the TypeError was really caused by a bad entry in
        # kwargs, this second call raises it again.
        return act_layer(**kwargs)