Update vision_transformer.py

pull/367/head
Zhiyuan Chen 5 years ago committed by GitHub
parent 28c0fa31fe
commit 23dc3f3974
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -139,12 +139,12 @@ class Block(nn.Module):
self.dropout = nn.Dropout(p=drop) self.dropout = nn.Dropout(p=drop)
def forward(self, x): def forward(self, x):
residual = x.clone() identity = x
x = self.norm1(x) x = self.norm1(x)
x = self.attn(x) x = self.attn(x)
x = self.dropout(x) x = self.dropout(x)
x = self.drop_path(x) x = self.drop_path(x)
x = x + residual x = x + identity
y = self.norm2(x) y = self.norm2(x)
y = self.mlp(y) y = self.mlp(y)
y = self.drop_path(y) y = self.drop_path(y)

Loading…
Cancel
Save