From a9eb48483564336049ad5d2283dea0cb334a7510 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 19 Oct 2019 14:48:30 -0700
Subject: [PATCH] Add memory efficient Swish impl

---
 timm/models/gen_efficientnet.py | 35 ++++++++++++++++++++++++++++++-----
 1 file changed, 30 insertions(+), 5 deletions(-)

diff --git a/timm/models/gen_efficientnet.py b/timm/models/gen_efficientnet.py
index c11782a7..00c2c86c 100644
--- a/timm/models/gen_efficientnet.py
+++ b/timm/models/gen_efficientnet.py
@@ -371,11 +371,36 @@ def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
     return arch_args
 
 
-def swish(x, inplace=False):
-    if inplace:
-        return x.mul_(x.sigmoid())
-    else:
-        return x * x.sigmoid()
+_USE_SWISH_OPT = True
+if _USE_SWISH_OPT:
+    class SwishAutoFn(torch.autograd.Function):
+        """ Memory Efficient Swish
+        From: https://blog.ceshine.net/post/pytorch-memory-swish/
+
+        Only the input is saved for backward; sigmoid(x) is recomputed
+        there rather than being kept alive from the forward pass.
+        """
+        @staticmethod
+        def forward(ctx, x):
+            result = x.mul(torch.sigmoid(x))
+            ctx.save_for_backward(x)
+            return result
+
+        @staticmethod
+        def backward(ctx, grad_output):
+            # saved_tensors replaces the deprecated saved_variables accessor
+            x = ctx.saved_tensors[0]
+            sigmoid_x = torch.sigmoid(x)
+            # d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))), s = sigmoid
+            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))
+
+
+    def swish(x, inplace=False):
+        # inplace is accepted for interface compatibility but not used here
+        return SwishAutoFn.apply(x)
+else:
+    def swish(x, inplace=False):
+        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())
 
 
 def sigmoid(x, inplace=False):