# Reconstructed from a whitespace-collapsed unified diff against
# timm/models/gen_efficientnet.py (index c11782a7..00c2c86c, hunk
# "@@ -371,11 +371,30 @@", located between the end of _decode_arch_def and
# the start of sigmoid). The patch replaces the plain
# `swish(x, inplace=False)` helper with the definitions below.

# Module-level switch: True selects the custom autograd implementation,
# False selects the plain tensor-op implementation.
_USE_SWISH_OPT = True
if _USE_SWISH_OPT:
    class SwishAutoFn(torch.autograd.Function):
        """Memory Efficient Swish.

        Computes ``x * sigmoid(x)`` as a custom autograd function. Only the
        input ``x`` is saved for the backward pass; the sigmoid is recomputed
        there instead of being kept from the forward pass.

        From: https://blog.ceshine.net/post/pytorch-memory-swish/
        """

        @staticmethod
        def forward(ctx, x):
            """Return ``x * sigmoid(x)`` and stash ``x`` for backward."""
            result = x.mul(torch.sigmoid(x))
            ctx.save_for_backward(x)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            """Return the gradient of the loss with respect to ``x``.

            Uses d/dx [x * s(x)] = s(x) * (1 + x * (1 - s(x))) where ``s`` is
            the sigmoid, recomputed here from the saved input.
            """
            # `saved_tensors` is the supported accessor for what
            # `save_for_backward` stored; `saved_variables` is deprecated.
            x, = ctx.saved_tensors
            sigmoid_x = torch.sigmoid(x)
            return grad_output * (sigmoid_x * (1 + x * (1 - sigmoid_x)))

    def swish(x, inplace=False):
        """Apply the swish activation ``x * sigmoid(x)`` via ``SwishAutoFn``.

        Args:
            x: Input tensor.
            inplace: Accepted for signature compatibility with the
                non-optimized variant but ignored here; ``x`` is never
                modified on this path.

        Returns:
            A new tensor holding ``x * sigmoid(x)``.
        """
        return SwishAutoFn.apply(x)
else:
    def swish(x, inplace=False):
        """Apply the swish activation ``x * sigmoid(x)`` with plain tensor ops.

        Args:
            x: Input tensor.
            inplace: If true, multiply into ``x`` itself and return it;
                otherwise return a new tensor.

        Returns:
            The activated tensor (``x`` itself when ``inplace`` is true).
        """
        return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid())