From 576d360f20c8299cfd909c86edad4afff45d3d01 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Fri, 22 Nov 2019 13:57:45 -0800
Subject: [PATCH] Bring in JIT version of optimized swish activation from
 gen_efficientnet as default (while working on feature extraction
 functionality here).

---
 timm/models/gen_efficientnet.py | 30 +++++++++++++++++++++---------
 1 file changed, 21 insertions(+), 9 deletions(-)

diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
index d3a5cb60..a7191025 100644
--- a/timm/models/gen_efficientnet.py
+++ b/timm/models/gen_efficientnet.py
@@ -373,25 +373,37 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
 _USE_SWISH_OPT = True
 if _USE_SWISH_OPT:
-    class SwishAutoFn(torch.autograd.Function):
-        """ Memory Efficient Swish
-        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+    @torch.jit.script
+    def swish_jit_fwd(x):
+        return x.mul(torch.sigmoid(x))
+
+
+    @torch.jit.script
+    def swish_jit_bwd(x, grad_output):
+        x_sigmoid = torch.sigmoid(x)
+        return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid)))
+
+
+    class SwishJitAutoFn(torch.autograd.Function):
+        """ torch.jit.script optimised Swish
+        Inspired by conversation between Jeremy Howard & Adam Paszke
+        https://twitter.com/jeremyphoward/status/1188251041835315200
         """
+
         @staticmethod
         def forward(ctx, x):
-            result = x.mul(torch.sigmoid(x))
             ctx.save_for_backward(x)
-            return result
+            return swish_jit_fwd(x)
 
         @staticmethod
         def backward(ctx, grad_output):
-            x = ctx.saved_variables[0]
-            sigmoid_x = torch.sigmoid(x)
-            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+            x = ctx.saved_tensors[0]
+            return swish_jit_bwd(x, grad_output)
 
 
     def swish(x, inplace=False):
-        return SwishAutoFn.apply(x)
+        # inplace ignored
+        return SwishJitAutoFn.apply(x)
 else:
     def swish(x, inplace=False):
         return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
 
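Note (not part of the patch): the memory saving comes from saving only the
input x for backward and recomputing sigmoid(x) inside swish_jit_bwd, instead
of letting autograd keep the intermediate sigmoid result alive, while
torch.jit.script fuses the elementwise ops in each scripted function. Below is
a minimal sanity-check sketch, assuming the patched module is importable as
timm.models.gen_efficientnet (only SwishJitAutoFn and swish come from this
diff; the rest is illustrative):

    import torch
    # Assumes the patched file above is on the import path.
    from timm.models.gen_efficientnet import SwishJitAutoFn, swish

    # gradcheck numerically verifies the hand-derived backward:
    # d/dx [x * sigmoid(x)] = sigmoid(x) * (1 + x * (1 - sigmoid(x))).
    # Double precision is needed for gradcheck's tolerances.
    x = torch.randn(8, dtype=torch.double, requires_grad=True)
    assert torch.autograd.gradcheck(SwishJitAutoFn.apply, (x,))

    # Forward values should match the naive eager implementation.
    x32 = x.detach().float()
    assert torch.allclose(swish(x32), x32 * torch.sigmoid(x32))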