""" MLP module w/ dropout and configurable activation layer

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn

from .helpers import to_2tuple


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
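
# A minimal usage sketch (not part of the original module; shapes and
# hyper-parameters below are illustrative assumptions): with in_features=768
# and hidden_features=3072 this matches a typical ViT-Base FFN block applied
# over a (batch, tokens, channels) tensor.
#
#   import torch
#   mlp = Mlp(in_features=768, hidden_features=3072, act_layer=nn.GELU, drop=0.1)
#   x = torch.randn(2, 197, 768)      # (batch, tokens, channels)
#   y = mlp(x)                        # -> torch.Size([2, 197, 768])
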
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        # override init of fc1 w/ gate portion set to weight near zero, bias=1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
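
# A minimal usage sketch (illustrative, not part of the original module):
# fc1 projects to hidden_features and the result is chunked in half, so
# hidden_features must be even and fc2 consumes hidden_features // 2.
# Calling init_weights() gives the gate half near-zero weights and bias 1,
# per the comment in init_weights above.
#
#   import torch
#   glu = GluMlp(in_features=256, hidden_features=512)   # 512 splits into 2 x 256
#   glu.init_weights()
#   x = torch.randn(8, 196, 256)
#   y = glu(x)                                            # -> torch.Size([8, 196, 256])
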
class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
            gate_layer=None, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # FIXME base reduction on gate property?
        else:
            self.gate = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.gate(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
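
# A minimal usage sketch (illustrative, not part of the original module): with
# gate_layer=None the block behaves like a plain Mlp. A gMLP-style model would
# instead pass a gating module constructed as gate_layer(hidden_features); the
# code expects such a gate to halve the channel dimension (see the FIXME above),
# which is why fc2 then takes hidden_features // 2 inputs.
#
#   import torch
#   gmlp = GatedMlp(in_features=256, hidden_features=1024)   # gate defaults to nn.Identity
#   x = torch.randn(4, 196, 256)
#   y = gmlp(x)                                               # -> torch.Size([4, 196, 256])
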
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU,
            norm_layer=None, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)

        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0])
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
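
# A minimal usage sketch (illustrative, not part of the original module):
# unlike the Linear-based blocks above, ConvMlp operates on NCHW feature maps
# via 1x1 convolutions and preserves the spatial resolution.
#
#   import torch
#   cmlp = ConvMlp(in_features=64, hidden_features=256, norm_layer=nn.BatchNorm2d)
#   x = torch.randn(2, 64, 56, 56)    # (batch, channels, height, width)
#   y = cmlp(x)                       # -> torch.Size([2, 64, 56, 56])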