@@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
 
 _USE_SWISH_OPT = True
 if _USE_SWISH_OPT:
-    class SwishAutoFn(torch.autograd.Function):
-        """ Memory Efficient Swish
-        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
+
+
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation btw Jeremy Howard & Adam Pazske
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
 
         @staticmethod
         def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
             ctx.save_for_backward(x)
-            return result
+            return swish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
+
 
     def swish(x, inplace=False):
-        return SwishAutoFn.apply(x)
+        # inplace ignored
+        return SwishJitAutoFn.apply(x)
 else:
     def swish(x, inplace=False):
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
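
Note on the hunk above (not part of the diff itself): both versions implement the same memory-efficient Swish, saving only x for backward and recomputing sigmoid(x) there, since d/dx[x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))), which is exactly what swish_jit_bwd returns. The new code just moves the forward/backward math into torch.jit.script functions and replaces the deprecated ctx.saved_variables with ctx.saved_tensors. The snippet below is a minimal verification sketch, assuming SwishJitAutoFn and the swish_jit_* helpers from the hunk are in scope; it is illustrative only and not part of the repository.

import torch

# Reference Swish traced by autograd, for comparison against the custom Function.
def swish_reference(x):
    return x * torch.sigmoid(x)

x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)

# Forward values should match.
y_ref = swish_reference(x)
y_jit = SwishJitAutoFn.apply(x)
assert torch.allclose(y_ref, y_jit)

# Gradients should match: the custom backward applies
# sigmoid(x) * (1 + x * (1 - sigmoid(x))) to the incoming gradient.
g = torch.randn_like(y_ref)
grad_ref, = torch.autograd.grad(y_ref, x, grad_outputs=g)
grad_jit, = torch.autograd.grad(y_jit, x, grad_outputs=g)
assert torch.allclose(grad_ref, grad_jit)

# gradcheck exercises the custom Function end to end (double precision required).
assert torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,))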