From 7a4be5c035166e0d0fd037c4f834a3fb39433b7a Mon Sep 17 00:00:00 2001
From: hwangdeyu
Date: Tue, 2 Feb 2021 18:02:41 +0800
Subject: [PATCH] add operator HardSwishJitAutoFn export to onnx

---
 timm/models/layers/activations_me.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py
index 0441f7c4..29bc0863 100644
--- a/timm/models/layers/activations_me.py
+++ b/timm/models/layers/activations_me.py
@@ -152,6 +152,13 @@ class HardSwishJitAutoFn(torch.autograd.Function):
         x = ctx.saved_tensors[0]
         return hard_swish_jit_bwd(x, grad_output)
 
+    @staticmethod
+    def symbolic(g, self):
+        input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
+        hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+        hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+        return g.op("Mul", self, hardtanh_)
+
 
 def hard_swish_me(x, inplace=False):
     return HardSwishJitAutoFn.apply(x)
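
Note (not part of the patch): the added symbolic() staticmethod is the hook the torch.onnx exporter looks up when it encounters a custom torch.autograd.Function during tracing, so HardSwishJitAutoFn is emitted as the standard decomposition x * clip(x + 3, 0, 6) / 6 instead of an unexportable PythonOp. Below is a minimal sketch of how the change could be exercised; the DemoNet wrapper and the printed op list are illustrative assumptions, not code from the patch, and Clip with min/max passed as inputs (as symbolic() builds it) requires opset_version >= 11.

# Minimal sketch: export a module that uses hard_swish_me() and inspect the
# resulting ONNX ops. DemoNet is a hypothetical wrapper; requires the onnx
# package in addition to torch and timm.
import io

import onnx
import torch

from timm.models.layers.activations_me import hard_swish_me


class DemoNet(torch.nn.Module):
    """Hypothetical wrapper so torch.onnx.export has a Module to trace."""
    def forward(self, x):
        return hard_swish_me(x)


buf = io.BytesIO()
# Clip with min/max supplied as inputs needs ONNX opset 11 or later.
torch.onnx.export(DemoNet(), torch.randn(1, 8), buf, opset_version=11)

model = onnx.load_model_from_string(buf.getvalue())
print(sorted({node.op_type for node in model.graph.node}))
# With symbolic() in place, the graph should contain Add, Clip, Div and Mul
# (plus the Constant inputs) rather than failing on a PythonOp.

The decomposition matches the usual hard-swish definition x * relu6(x + 3) / 6, so downstream ONNX runtimes only need standard operators to run the exported model.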