diff --git a/timm/models/layers/activations_me.py b/timm/models/layers/activations_me.py
index 0441f7c4..29bc0863 100644
--- a/timm/models/layers/activations_me.py
+++ b/timm/models/layers/activations_me.py
@@ -152,6 +152,13 @@ class HardSwishJitAutoFn(torch.autograd.Function):
         x = ctx.saved_tensors[0]
         return hard_swish_jit_bwd(x, grad_output)
 
+    @staticmethod
+    def symbolic(g, self):
+        input = g.op("Add", self, g.op('Constant', value_t=torch.tensor(3, dtype=torch.float)))
+        hardtanh_ = g.op("Clip", input, g.op('Constant', value_t=torch.tensor(0, dtype=torch.float)), g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+        hardtanh_ = g.op("Div", hardtanh_, g.op('Constant', value_t=torch.tensor(6, dtype=torch.float)))
+        return g.op("Mul", self, hardtanh_)
+
 
 def hard_swish_me(x, inplace=False):
     return HardSwishJitAutoFn.apply(x)