Fix inplace arg compat for GELU and PreLU via activation factory

pull/297/head
Ross Wightman 4 years ago
parent fd962c4b4a
commit 5f4b6076d8

@ -119,3 +119,27 @@ class HardMish(nn.Module):
def forward(self, x): def forward(self, x):
return hard_mish(x, self.inplace) return hard_mish(x, self.inplace)
class PReLU(nn.PReLU):
"""Applies PReLU (w/ dummy inplace arg)
"""
def __init__(self, num_parameters: int = 1, init: float = 0.25, inplace: bool = False) -> None:
super(PReLU, self).__init__(num_parameters=num_parameters, init=init)
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.prelu(input, self.weight)
def gelu(x: torch.Tensor, inplace: bool = False) -> torch.Tensor:
return F.gelu(x)
class GELU(nn.Module):
"""Applies the Gaussian Error Linear Units function (w/ dummy inplace arg)
"""
def __init__(self, inplace: bool = False):
super(GELU, self).__init__()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return F.gelu(input)

@ -19,10 +19,9 @@ _ACT_FN_DEFAULT = dict(
relu6=F.relu6, relu6=F.relu6,
leaky_relu=F.leaky_relu, leaky_relu=F.leaky_relu,
elu=F.elu, elu=F.elu,
prelu=F.prelu,
celu=F.celu, celu=F.celu,
selu=F.selu, selu=F.selu,
gelu=F.gelu, gelu=gelu,
sigmoid=sigmoid, sigmoid=sigmoid,
tanh=tanh, tanh=tanh,
hard_sigmoid=hard_sigmoid, hard_sigmoid=hard_sigmoid,
@ -56,10 +55,10 @@ _ACT_LAYER_DEFAULT = dict(
relu6=nn.ReLU6, relu6=nn.ReLU6,
leaky_relu=nn.LeakyReLU, leaky_relu=nn.LeakyReLU,
elu=nn.ELU, elu=nn.ELU,
prelu=nn.PReLU, prelu=PReLU,
celu=nn.CELU, celu=nn.CELU,
selu=nn.SELU, selu=nn.SELU,
gelu=nn.GELU, gelu=GELU,
sigmoid=Sigmoid, sigmoid=Sigmoid,
tanh=Tanh, tanh=Tanh,
hard_sigmoid=HardSigmoid, hard_sigmoid=HardSigmoid,

Loading…
Cancel
Save