""" MLP module w/ dropout and configurable activation layer
|
|
|
|
Hacked together by / Copyright 2020 Ross Wightman
|
|
"""
|
|
from torch import nn as nn
|
|
|
|
|
|
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


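# Illustrative note (not part of the original file): in ViT / MLP-Mixer style
# blocks, Mlp is typically built with hidden_features = 4 * in_features, e.g.
# Mlp(in_features=768, hidden_features=3072, act_layer=nn.GELU, drop=0.1),
# and applied to token tensors of shape (batch, tokens, channels).

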
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features * 2)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        # split the doubled projection into value and gate halves; the activated
        # gate modulates the value path (GLU-style gating)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
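

# ---------------------------------------------------------------------------
# Illustrative smoke test (not part of the original module): a minimal sketch
# checking that both blocks preserve the (batch, tokens, channels) shape.
# Assumes torch is installed; the dimensions below are arbitrary example values.
if __name__ == '__main__':
    import torch

    x = torch.randn(2, 196, 256)  # (batch, tokens, channels)

    mlp = Mlp(in_features=256, hidden_features=1024, drop=0.1)
    glu_mlp = GluMlp(in_features=256, hidden_features=1024, drop=0.1)

    assert mlp(x).shape == (2, 196, 256)
    assert glu_mlp(x).shape == (2, 196, 256)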