Merge pull request #401 from hwangdeyu/deyu/add_HardSwishJitAutoFn_operator

add HardSwishJitAutoFn operator export to onnx
pull/413/head^2
Ross Wightman 4 years ago committed by GitHub
commit ea36a78cff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -152,6 +152,13 @@ class HardSwishJitAutoFn(torch.autograd.Function):
x = ctx.saved_tensors[0] x = ctx.saved_tensors[0]
return hard_swish_jit_bwd(x, grad_output) return hard_swish_jit_bwd(x, grad_output)
@staticmethod
def symbolic(g, self):
input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
return g.op("Mul", self, hardtanh_)
def hard_swish_me(x, inplace=False): def hard_swish_me(x, inplace=False):
return HardSwishJitAutoFn.apply(x) return HardSwishJitAutoFn.apply(x)

Loading…
Cancel
Save