Add custom grad tests, fix cut & paste error with hard_mish ME, add a few more pytorch act fns to factory

pull/155/head
Ross Wightman 4 years ago
parent 6c7932fe75
commit 151679c2f1

@ -0,0 +1,71 @@
import pytest
import torch
import torch.nn as nn
import platform
import os
from timm.models.layers import create_act_layer, get_act_layer, set_layer_config
class MLP(nn.Module):
def __init__(self, act_layer="relu"):
super(MLP, self).__init__()
self.fc1 = nn.Linear(1000, 100)
self.act = create_act_layer(act_layer, inplace=True)
self.fc2 = nn.Linear(100, 10)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.fc2(x)
return x
def _run_act_layer_grad(act_type):
x = torch.rand(10, 1000) * 10
m = MLP(act_layer=act_type)
def _run(x, act_layer=''):
if act_layer:
# replace act layer if set
m.act = create_act_layer(act_layer, inplace=True)
out = m(x)
l = (out - 0).pow(2).sum()
return l
out_me = _run(x)
with set_layer_config(scriptable=True):
out_jit = _run(x, act_type)
assert torch.isclose(out_jit, out_me)
with set_layer_config(no_jit=True):
out_basic = _run(x, act_type)
assert torch.isclose(out_basic, out_jit)
def test_swish_grad():
for _ in range(100):
_run_act_layer_grad('swish')
def test_mish_grad():
for _ in range(100):
_run_act_layer_grad('mish')
def test_hard_sigmoid_grad():
for _ in range(100):
_run_act_layer_grad('hard_sigmoid')
def test_hard_swish_grad():
for _ in range(100):
_run_act_layer_grad('hard_swish')
def test_hard_mish_grad():
for _ in range(100):
_run_act_layer_grad('hard_mish')

@ -185,12 +185,12 @@ class HardMishJitAutoFn(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
return mish_jit_fwd(x)
return hard_mish_jit_fwd(x)
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_tensors[0]
return mish_jit_bwd(x, grad_output)
return hard_mish_jit_bwd(x, grad_output)
def hard_mish_me(x, inplace: bool = False):

@ -9,6 +9,12 @@ _ACT_FN_DEFAULT = dict(
mish=mish,
relu=F.relu,
relu6=F.relu6,
leaky_relu=F.leaky_relu,
elu=F.elu,
prelu=F.prelu,
celu=F.celu,
selu=F.selu,
gelu=F.gelu,
sigmoid=sigmoid,
tanh=tanh,
hard_sigmoid=hard_sigmoid,
@ -37,6 +43,11 @@ _ACT_LAYER_DEFAULT = dict(
mish=Mish,
relu=nn.ReLU,
relu6=nn.ReLU6,
elu=nn.ELU,
prelu=nn.PReLU,
celu=nn.CELU,
selu=nn.SELU,
gelu=nn.GELU,
sigmoid=Sigmoid,
tanh=Tanh,
hard_sigmoid=HardSigmoid,

Loading…
Cancel
Save