@@ -4,6 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 from torch import nn as nn
 
+from .helpers import to_2tuple
+
 
 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
@@ -12,17 +14,20 @@ class Mlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
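
Reviewer note, not part of the diff: a minimal usage sketch of what this Mlp change enables, assuming timm's to_2tuple semantics (a scalar is repeated, a 2-tuple passes through) and that the layer is importable as timm.models.layers.Mlp; the sizes are illustrative only.

    import torch
    from timm.models.layers import Mlp

    # separate probabilities: drop1 (after the activation) = 0.1, drop2 (after fc2) = 0.0
    mlp = Mlp(in_features=192, hidden_features=768, drop=(0.1, 0.0))
    y = mlp(torch.randn(2, 196, 192))  # -> torch.Size([2, 196, 192])

    # a plain float keeps the old behaviour: both dropouts use p=0.1
    mlp = Mlp(in_features=192, hidden_features=768, drop=0.1)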
@@ -35,10 +40,13 @@ class GluMlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         assert hidden_features % 2 == 0
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features // 2, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
@@ -50,9 +58,9 @@ class GluMlp(nn.Module):
         x = self.fc1(x)
         x, gates = x.chunk(2, dim=-1)
         x = x * self.act(gates)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
@@ -64,8 +72,11 @@ class GatedMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         if gate_layer is not None:
             assert hidden_features % 2 == 0
             self.gate = gate_layer(hidden_features)
@@ -73,15 +84,15 @@ class GatedMlp(nn.Module):
         else:
             self.gate = nn.Identity()
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])
 
     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.gate(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
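
Reviewer note, not part of the diff: the gated variants consume the same drop_probs split. A short sketch under stated assumptions: GluMlp/GatedMlp importable from timm.models.layers, and SpatialGatingUnit (timm's gMLP gate, assumed to live in timm.models.mlp_mixer) used as the gate_layer; sizes are illustrative.

    from functools import partial
    import torch
    from timm.models.layers import GluMlp, GatedMlp
    from timm.models.mlp_mixer import SpatialGatingUnit

    x = torch.randn(2, 196, 256)

    # GluMlp: fc1 output is chunked in two, one half gates the other, so fc2 takes hidden // 2
    glu = GluMlp(in_features=256, hidden_features=1024, drop=(0.1, 0.0))
    print(glu(x).shape)  # torch.Size([2, 196, 256]); drop1 after the gating, drop2 after fc2

    # GatedMlp: the gate halves the channel dim (spatial gating over seq_len=196)
    gated = GatedMlp(in_features=256, hidden_features=1536,
                     gate_layer=partial(SpatialGatingUnit, seq_len=196), drop=0.1)
    print(gated(x).shape)  # torch.Size([2, 196, 256]); both dropouts use p=0.1 via to_2tuple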