diff --git a/tests/test_layers.py b/tests/test_layers.py
new file mode 100644
index 00000000..714cb444
--- /dev/null
+++ b/tests/test_layers.py
@@ -0,0 +1,71 @@
+import pytest
+import torch
+import torch.nn as nn
+import platform
+import os
+
+from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
+
+
+class MLP(nn.Module):
+    def __init__(self, act_layer="relu"):
+        super(MLP, self).__init__()
+        self.fc1 = nn.Linear(1000, 100)
+        self.act = create_act_layer(act_layer, inplace=True)
+        self.fc2 = nn.Linear(100, 10)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.fc2(x)
+        return x
+
+
+def _run_act_layer_grad(act_type):
+    x = torch.rand(10, 1000) * 10
+    m = MLP(act_layer=act_type)
+
+    def _run(x, act_layer=''):
+        if act_layer:
+            # replace act layer if set
+            m.act = create_act_layer(act_layer, inplace=True)
+        out = m(x)
+        l = (out - 0).pow(2).sum()
+        return l
+
+    out_me = _run(x)
+
+    with set_layer_config(scriptable=True):
+        out_jit = _run(x, act_type)
+
+    assert torch.isclose(out_jit, out_me)
+
+    with set_layer_config(no_jit=True):
+        out_basic = _run(x, act_type)
+
+    assert torch.isclose(out_basic, out_jit)
+
+
+def test_swish_grad():
+    for _ in range(100):
+        _run_act_layer_grad('swish')
+
+
+def test_mish_grad():
+    for _ in range(100):
+        _run_act_layer_grad('mish')
+
+
+def test_hard_sigmoid_grad():
+    for _ in range(100):
+        _run_act_layer_grad('hard_sigmoid')
+
+
+def test_hard_swish_grad():
+    for _ in range(100):
+        _run_act_layer_grad('hard_swish')
+
+
+def test_hard_mish_grad():
+    for _ in range(100):
+        _run_act_layer_grad('hard_mish')
diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py
index 9c492f1e..b81f7165 100644
--- a/timm/models/layers/activations_me.py
+++ b/timm/models/layers/activations_me.py
@@ -185,12 +185,12 @@ class HardMishJitAutoFn(torch.autograd.Function):
     @staticmethod
     def forward(ctx, x):
         ctx.save_for_backward(x)
-        return mish_jit_fwd(x)
+        return hard_mish_jit_fwd(x)
 
     @staticmethod
     def backward(ctx, grad_output):
         x = ctx.saved_tensors[0]
-        return mish_jit_bwd(x, grad_output)
+        return hard_mish_jit_bwd(x, grad_output)
 
 
 def hard_mish_me(x, inplace: bool = False):
diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py
index 66ab1e84..6404d62f 100644
--- a/timm/models/layers/create_act.py
+++ b/timm/models/layers/create_act.py
@@ -9,6 +9,12 @@ _ACT_FN_DEFAULT = dict(
     mish=mish,
     relu=F.relu,
     relu6=F.relu6,
+    leaky_relu=F.leaky_relu,
+    elu=F.elu,
+    prelu=F.prelu,
+    celu=F.celu,
+    selu=F.selu,
+    gelu=F.gelu,
     sigmoid=sigmoid,
     tanh=tanh,
     hard_sigmoid=hard_sigmoid,
@@ -37,6 +43,11 @@ _ACT_LAYER_DEFAULT = dict(
     mish=Mish,
     relu=nn.ReLU,
     relu6=nn.ReLU6,
+    elu=nn.ELU,
+    prelu=nn.PReLU,
+    celu=nn.CELU,
+    selu=nn.SELU,
+    gelu=nn.GELU,
     sigmoid=Sigmoid,
     tanh=Tanh,
     hard_sigmoid=HardSigmoid,