Cleanup re-use of Dropout modules in Mlp modules after some twitter feedback :p

more_datasets
Ross Wightman 3 years ago
parent 71f00bfe9e
commit f658a72e72
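
In short: instead of creating one nn.Dropout and calling it twice in forward(), each MLP variant now gets distinct dropout modules (drop1/drop2, or drop1/drop3 in SpatialMlp), and the drop argument is expanded into a pair of probabilities via to_2tuple, so a scalar keeps the old behaviour while a 2-tuple can set a different rate per position. A minimal sketch of what to_2tuple from .helpers does (roughly, not the verbatim timm implementation):

    import collections.abc
    from itertools import repeat

    def to_2tuple(x):
        # pass an existing iterable through, otherwise broadcast the scalar to a pair
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, 2))

    to_2tuple(0.1)         # -> (0.1, 0.1), same effect as the old shared self.drop
    to_2tuple((0.1, 0.0))  # -> (0.1, 0.0), different rates for drop1 and drop2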

@@ -4,6 +4,8 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 from torch import nn as nn
+
+from .helpers import to_2tuple


 class Mlp(nn.Module):
     """ MLP as used in Vision Transformer, MLP-Mixer and related networks
@@ -12,17 +14,20 @@ class Mlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])

     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
@@ -35,10 +40,13 @@ class GluMlp(nn.Module):
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
         assert hidden_features % 2 == 0
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         self.fc2 = nn.Linear(hidden_features // 2, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])

     def init_weights(self):
         # override init of fc1 w/ gate portion set to weight near zero, bias=1
@@ -50,9 +58,9 @@ class GluMlp(nn.Module):
         x = self.fc1(x)
         x, gates = x.chunk(2, dim=-1)
         x = x * self.act(gates)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
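
For context on the GluMlp lines above: fc1 produces hidden_features channels that are split into a value half and a gate half, and the activated gate multiplies the value (the usual GLU pattern); a small illustration of the chunk step, assuming the default sigmoid gate:

    import torch

    h = torch.randn(2, 196, 8)       # fc1 output, hidden_features = 8
    x, gates = h.chunk(2, dim=-1)    # two halves of shape (2, 196, 4)
    x = x * torch.sigmoid(gates)     # gated value, then drop1 and fc2 as in the diff
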
@@ -64,8 +72,11 @@ class GatedMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.fc1 = nn.Linear(in_features, hidden_features)
         self.act = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         if gate_layer is not None:
             assert hidden_features % 2 == 0
             self.gate = gate_layer(hidden_features)
@@ -73,15 +84,15 @@ class GatedMlp(nn.Module):
         else:
             self.gate = nn.Identity()
         self.fc2 = nn.Linear(hidden_features, out_features)
-        self.drop = nn.Dropout(drop)
+        self.drop2 = nn.Dropout(drop_probs[1])

     def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         x = self.gate(x)
         x = self.fc2(x)
-        x = self.drop(x)
+        x = self.drop2(x)
         return x
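
With the split modules, an Mlp (and the other variants) can take either a scalar drop rate or a per-position pair; a usage sketch, assuming the timm.models.layers import path of this era:

    import torch
    from timm.models.layers import Mlp

    mlp = Mlp(in_features=384, hidden_features=1536, drop=(0.1, 0.0))  # drop1=0.1 after act, drop2=0.0 after fc2
    y = mlp(torch.randn(8, 196, 384))  # a scalar drop=0.1 would still apply 0.1 at both spots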

@@ -45,6 +45,8 @@ class SpatialMlp(nn.Module):
         super().__init__()
         out_features = out_features or in_features
         hidden_features = hidden_features or in_features
+        drop_probs = to_2tuple(drop)
+
         self.in_features = in_features
         self.out_features = out_features
         self.spatial_conv = spatial_conv
@@ -55,9 +57,9 @@ class SpatialMlp(nn.Module):
                 hidden_features = in_features * 2
         self.hidden_features = hidden_features
         self.group = group
-        self.drop = nn.Dropout(drop)
         self.conv1 = nn.Conv2d(in_features, hidden_features, 1, stride=1, padding=0, bias=False)
         self.act1 = act_layer()
+        self.drop1 = nn.Dropout(drop_probs[0])
         if self.spatial_conv:
             self.conv2 = nn.Conv2d(
                 hidden_features, hidden_features, 3, stride=1, padding=1, groups=self.group, bias=False)
@@ -66,16 +68,17 @@ class SpatialMlp(nn.Module):
             self.conv2 = None
             self.act2 = None
         self.conv3 = nn.Conv2d(hidden_features, out_features, 1, stride=1, padding=0, bias=False)
+        self.drop3 = nn.Dropout(drop_probs[1])

     def forward(self, x):
         x = self.conv1(x)
         x = self.act1(x)
-        x = self.drop(x)
+        x = self.drop1(x)
         if self.conv2 is not None:
             x = self.conv2(x)
             x = self.act2(x)
         x = self.conv3(x)
-        x = self.drop(x)
+        x = self.drop3(x)
         return x
