You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
219 lines
5.7 KiB
219 lines
5.7 KiB
5 years ago
|
""" Activations (memory-efficient w/ custom autograd)

A collection of activation functions and modules with a common interface so that they can
easily be swapped. All have an `inplace` arg even if it is not used.

These activations are not compatible with jit scripting or ONNX export of the model; please use either
the JIT or basic versions of the activations.

Hacked together by / Copyright 2020 Ross Wightman
"""
|
||
|
|
||
|
import torch
|
||
|
from torch import nn as nn
|
||
|
from torch.nn import functional as F
|
||
|
|
||
|
|
||
|
@torch.jit.script
def swish_jit_fwd(x):
    """TorchScript-compiled Swish forward pass: ``x * sigmoid(x)``."""
    gate = x.sigmoid()
    return x * gate
|
||
|
|
||
|
|
||
|
@torch.jit.script
def swish_jit_bwd(x, grad_output):
    """TorchScript-compiled Swish backward pass.

    With s = sigmoid(x), d/dx [x * s] = s * (1 + x * (1 - s)); that local
    derivative is multiplied by the incoming ``grad_output``.
    """
    sig = x.sigmoid()
    local_grad = sig * (1 + x * (1 - sig))
    return grad_output * local_grad
|
||
|
|
||
|
|
||
|
class SwishJitAutoFn(torch.autograd.Function):
    """ torch.jit.script optimised Swish w/ memory-efficient checkpoint
    Inspired by conversation btw Jeremy Howard & Adam Pazske
    https://twitter.com/jeremyphoward/status/1188251041835315200

    Only the input ``x`` is saved for backward; the gradient is recomputed
    from it by ``swish_jit_bwd`` instead of keeping intermediate tensors.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return swish_jit_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, x):
        # ONNX export graph: Mul(x, Sigmoid(x)).
        sig = g.op("Sigmoid", x)
        return g.op("Mul", x, sig)
|
||
|
|
||
|
|
||
|
def swish_me(x, inplace: bool = False):
    """Functional form of the memory-efficient Swish.

    Args:
        x: tensor to activate.
        inplace: accepted only for interface compatibility with the other
            activations; it is ignored.
    """
    return SwishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
class SwishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient Swish autograd function.

    ``inplace`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return SwishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
@torch.jit.script
def mish_jit_fwd(x):
    """TorchScript-compiled Mish forward pass: ``x * tanh(softplus(x))``."""
    gate = torch.tanh(F.softplus(x))
    return x * gate
|
||
|
|
||
|
|
||
|
@torch.jit.script
def mish_jit_bwd(x, grad_output):
    """TorchScript-compiled Mish backward pass.

    With t = tanh(softplus(x)), d/dx [x * t] = t + x * sigmoid(x) * (1 - t^2);
    that local derivative is multiplied by ``grad_output``.
    """
    tanh_sp = F.softplus(x).tanh()
    sech_sq = 1 - tanh_sp * tanh_sp
    local_grad = tanh_sp + x * torch.sigmoid(x) * sech_sq
    return grad_output.mul(local_grad)
|
||
|
|
||
|
|
||
|
class MishJitAutoFn(torch.autograd.Function):
    """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681
    A memory efficient, jit scripted variant of Mish

    Only the input ``x`` is saved for backward; ``mish_jit_bwd`` recomputes the
    gradient from it. ``symbolic`` provides an ONNX graph built from standard
    ONNX ops (Softplus, Tanh, Mul), matching the other exportable activations
    in this module.
    """

    @staticmethod
    def symbolic(g, x):
        # ONNX export graph: Mul(x, Tanh(Softplus(x))).
        return g.op("Mul", x, g.op("Tanh", g.op("Softplus", x)))

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return mish_jit_bwd(x, grad_output)
|
||
|
|
||
|
|
||
|
def mish_me(x, inplace: bool = False):
    """Functional form of the memory-efficient Mish.

    Args:
        x: tensor to activate.
        inplace: accepted only for interface compatibility; ignored.
    """
    return MishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
class MishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient Mish autograd function.

    ``inplace`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return MishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_sigmoid_jit_fwd(x, inplace: bool = False):
    """TorchScript-compiled hard sigmoid forward: ``clamp(x + 3, 0, 6) / 6``.

    ``inplace`` is accepted for interface compatibility and ignored.
    """
    shifted = x + 3
    return torch.clamp(shifted, min=0, max=6).div(6.)
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_sigmoid_jit_bwd(x, grad_output):
    """TorchScript-compiled hard sigmoid backward.

    The local derivative is 1/6 for -3 <= x <= 3 and 0 elsewhere; it is
    multiplied by ``grad_output``.
    """
    in_ramp = (x >= -3.) & (x <= 3.)
    slope = torch.ones_like(x) * in_ramp / 6.
    return grad_output * slope
|
||
|
|
||
|
|
||
|
class HardSigmoidJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted hard sigmoid: ``clamp(x + 3, 0, 6) / 6``.

    Only the input ``x`` is saved for backward; ``hard_sigmoid_jit_bwd``
    recomputes the gradient from it. ``symbolic`` maps the op to the ONNX
    ``HardSigmoid`` operator (``max(0, min(1, alpha * x + beta))``). With
    alpha = 1/6 and beta = 0.5, that is the same function as the forward pass.
    """

    @staticmethod
    def symbolic(g, x):
        # (x + 3) / 6 == x / 6 + 0.5, and clamping to [0, 6] before dividing
        # equals clamping to [0, 1] after, so ONNX HardSigmoid reproduces it.
        return g.op("HardSigmoid", x, alpha_f=1.0 / 6, beta_f=0.5)

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_sigmoid_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_sigmoid_jit_bwd(x, grad_output)
|
||
|
|
||
|
|
||
|
def hard_sigmoid_me(x, inplace: bool = False):
    """Functional form of the memory-efficient hard sigmoid.

    Args:
        x: tensor to activate.
        inplace: accepted only for interface compatibility; ignored.
    """
    return HardSigmoidJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
class HardSigmoidMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient hard sigmoid autograd function.

    ``inplace`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardSigmoidJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_swish_jit_fwd(x):
    """TorchScript-compiled hard swish forward: ``x * clamp(x + 3, 0, 6) / 6``."""
    gate = (x + 3).clamp(min=0, max=6).div(6.)
    return x * gate
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_swish_jit_bwd(x, grad_output):
    """TorchScript-compiled hard swish backward.

    Local derivative is 0 for x < -3, x / 3 + 0.5 for -3 <= x <= 3, and 1 for
    x > 3; it is multiplied by ``grad_output``.
    """
    saturated = torch.ones_like(x) * (x >= 3.)
    ramp = x / 3. + .5
    on_ramp = (x >= -3.) & (x <= 3.)
    return grad_output * torch.where(on_ramp, ramp, saturated)
|
||
|
|
||
|
|
||
|
class HardSwishJitAutoFn(torch.autograd.Function):
    """A memory efficient, jit-scripted HardSwish activation

    Only the input ``x`` is saved for backward; ``hard_swish_jit_bwd``
    recomputes the gradient from it. ``symbolic`` emits an ONNX graph for
    ``x * clip(x + 3, 0, 6) / 6``.
    """

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_swish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_swish_jit_bwd(x, grad_output)

    @staticmethod
    def symbolic(g, self):
        # ``self`` is the graph value of the op's input (name kept for
        # compatibility). Clip receives its bounds as graph inputs, the form the
        # ONNX Clip op uses from opset 11 onward.
        def _const(value):
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.float))

        shifted = g.op("Add", self, _const(3))
        gate = g.op("Clip", shifted, _const(0), _const(6))
        gate = g.op("Div", gate, _const(6))
        return g.op("Mul", self, gate)
|
||
|
|
||
5 years ago
|
|
||
|
def hard_swish_me(x, inplace: bool = False):
    """Functional form of the memory-efficient hard swish.

    Args:
        x: tensor to activate.
        inplace: accepted only for interface compatibility; ignored.
    """
    return HardSwishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
class HardSwishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient hard swish autograd function.

    ``inplace`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardSwishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_mish_jit_fwd(x):
    """TorchScript-compiled hard mish forward: ``0.5 * x * clamp(x + 2, 0, 2)``."""
    gate = (x + 2).clamp(min=0, max=2)
    return 0.5 * x * gate
|
||
|
|
||
|
|
||
|
@torch.jit.script
def hard_mish_jit_bwd(x, grad_output):
    """TorchScript-compiled hard mish backward.

    Local derivative is 0 for x < -2, x + 1 for -2 <= x <= 0, and 1 for x > 0;
    it is multiplied by ``grad_output``.
    """
    linear_part = torch.ones_like(x) * (x >= -2.)
    on_curve = (x >= -2.) & (x <= 0.)
    return grad_output * torch.where(on_curve, x + 1., linear_part)
|
||
|
|
||
|
|
||
|
class HardMishJitAutoFn(torch.autograd.Function):
    """ A memory efficient, jit scripted variant of Hard Mish
    Experimental, based on notes by Mish author Diganta Misra at
      https://github.com/digantamisra98/H-Mish/blob/0da20d4bc58e696b6803f2523c58d3c8a82782d0/README.md

    Computes ``0.5 * x * clamp(x + 2, 0, 2)``. Only the input ``x`` is saved for
    backward; ``hard_mish_jit_bwd`` recomputes the gradient from it.
    ``symbolic`` emits an equivalent ONNX graph built the same way as the
    HardSwish export (Add / Clip with constant bound inputs / Mul).
    """

    @staticmethod
    def symbolic(g, x):
        # Clip receives its bounds as graph inputs, the form the ONNX Clip op
        # uses from opset 11 onward.
        def _const(value):
            return g.op('Constant', value_t=torch.tensor(value, dtype=torch.float))

        gate = g.op("Clip", g.op("Add", x, _const(2.)), _const(0.), _const(2.))
        half_x = g.op("Mul", x, _const(0.5))
        return g.op("Mul", half_x, gate)

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return hard_mish_jit_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        x = ctx.saved_tensors[0]
        return hard_mish_jit_bwd(x, grad_output)
|
||
5 years ago
|
|
||
|
|
||
|
def hard_mish_me(x, inplace: bool = False):
    """Functional form of the memory-efficient hard mish.

    Args:
        x: tensor to activate.
        inplace: accepted only for interface compatibility; ignored.
    """
    return HardMishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
class HardMishMe(nn.Module):
    """``nn.Module`` wrapper around the memory-efficient hard mish autograd function.

    ``inplace`` is accepted for interface compatibility and ignored.
    """

    def __init__(self, inplace: bool = False):
        super().__init__()

    def forward(self, x):
        return HardMishJitAutoFn.apply(x)
|
||
|
|
||
|
|
||
|
|