Add memory efficient Swish impl

pull/52/head
Ross Wightman 5 years ago
parent 187ecbafbe
commit a9eb484835

@ -371,11 +371,30 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
return arch_args return arch_args
def swish(x, inplace=False): _USE_SWISH_OPT = True
if inplace: if _USE_SWISH_OPT:
return x.mul_(x.sigmoid()) class SwishAutoFn(torch.autograd.Function):
else: """ Memory Efficient Swish
return x * x.sigmoid() From: https://blog.ceshine.net/post/pytorch-memory-swish/
"""
@staticmethod
def forward(ctx, x):
result = x.mul(torch.sigmoid(x))
ctx.save_for_backward(x)
return result
@staticmethod
def backward(ctx, grad_output):
x = ctx.saved_variables[0]
sigmoid_x = torch.sigmoid(x)
return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
def swish(x, inplace=False):
return SwishAutoFn.apply(x)
else:
def swish(x, inplace=False):
return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
def sigmoid(x, inplace=False): def sigmoid(x, inplace=False):

Loading…
Cancel
Save