Bring in JIT version of optimized swish activation from gen_efficientnet as default (while working on feature extraction functionality here).

pull/52/head
Ross Wightman 5 years ago
parent 1f39d15f15
commit 576d360f20

@@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
 _USE_SWISH_OPT = True
 if _USE_SWISH_OPT:
-    class SwishAutoFn(torch.autograd.Function):
-        """ Memory Efficient Swish
-        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
+
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation btw Jeremy Howard & Adam Pazske
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
         @staticmethod
         def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
             ctx.save_for_backward(x)
-            return result
+            return swish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
 
     def swish(x, inplace=False):
-        return SwishAutoFn.apply(x)
+        # inplace ignored
+        return SwishJitAutoFn.apply(x)
 else:
     def swish(x, inplace=False):
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
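
For context, a minimal self-contained sketch (not part of the commit) of how the pattern above can be sanity-checked: only the input x is saved for backward instead of keeping the whole autograd graph of x * sigmoid(x) alive, the forward/backward math lives in torch.jit.script functions, and the custom gradient can be compared against the plain eager swish with torch.autograd.gradcheck. The definitions mirror the diff; the test harness (tensor shape, dtype, tolerances) is illustrative only.

import torch

@torch.jit.script
def swish_jit_fwd(x):
    return x.mul(torch.sigmoid(x))

@torch.jit.script
def swish_jit_bwd(x, grad_output):
    x_sigmoid = torch.sigmoid(x)
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x)))
    return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))

class SwishJitAutoFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # only the input is kept for backward
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return swish_jit_bwd(x, grad_output)

# Illustrative check (assumed harness, not from the commit):
x = torch.randn(8, 16, dtype=torch.double, requires_grad=True)

# numerically verify the hand-written backward
torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,))

# forward output matches the naive eager swish
assert torch.allclose(SwishJitAutoFn.apply(x), x * torch.sigmoid(x))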
