""" MLP module w/ dropout and configurable activation layer

Hacked together by / Copyright 2020 Ross Wightman
"""
from torch import nn as nn


class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    Pipeline: ``fc1 -> act -> drop -> fc2 -> drop``. A single ``nn.Dropout``
    instance is reused at both dropout points.

    Args:
        in_features: input width given to the first ``nn.Linear``.
        hidden_features: output width of ``fc1``; falls back to ``in_features`` when falsy.
        out_features: output width of ``fc2``; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable that builds the activation module.
        drop: probability handed to ``nn.Dropout``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        # A falsy width (None or 0) means "use the input width".
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two-layer MLP to ``x`` and return the result."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))


class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    The output of ``fc1`` is split into two equal chunks along the last dim;
    the first chunk is multiplied by ``act`` of the second, then passed
    through ``drop -> fc2 -> drop``.

    Args:
        in_features: input width given to the first ``nn.Linear``.
        hidden_features: output width of ``fc1`` (must be even); falls back to
            ``in_features`` when falsy. ``fc2`` consumes half of this width.
        out_features: output width of ``fc2``; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable that builds the gate activation module.
        drop: probability handed to ``nn.Dropout``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # forward() splits the fc1 output into two halves, so the width has to be even.
        assert hidden_features % 2 == 0

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Project, gate one half of the projection with the other, and project again."""
        values, gates = self.fc1(x).chunk(2, dim=-1)
        gated = self.drop(values * self.act(gates))
        return self.drop(self.fc2(gated))


class GatedMlp(nn.Module):
    """ MLP as used in gMLP

    Pipeline: ``fc1 -> act -> drop -> gate -> fc2 -> drop``.

    Args:
        in_features: input width given to the first ``nn.Linear``.
        hidden_features: output width of ``fc1``; falls back to ``in_features`` when falsy.
            Must be even when ``gate_layer`` is supplied.
        out_features: output width of ``fc2``; falls back to ``in_features`` when falsy.
        act_layer: zero-argument callable that builds the activation module.
        gate_layer: optional callable invoked as ``gate_layer(hidden_features)`` to build
            the gate module. When given, ``fc2`` is sized for half the hidden width.
            When ``None``, the gate is ``nn.Identity`` and ``fc2`` takes the full width.
        drop: probability handed to ``nn.Dropout``.
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is None:
            self.gate = nn.Identity()
            fc2_in_features = hidden_features
        else:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            fc2_in_features = hidden_features // 2  # FIXME base reduction on gate property?
        self.fc2 = nn.Linear(fc2_in_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply projection, activation, dropout, the gate, then the output projection."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(self.gate(hidden)))