diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 72f3a61a..50470eaf 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -138,9 +138,16 @@ class Block(nn.Module): self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) def forward(self, x): - x = x + self.drop_path(self.attn(self.norm1(x))) - x = x + self.drop_path(self.mlp(self.norm2(x))) - return x + residual = x.clone() + x = self.norm1(x) + x = self.attn(x) + x = self.dropout(x) + x = self.drop_path(x) + x = x + residual + y = self.norm2(x) + y = self.mlp(y) + y = self.drop_path(y) + return x + y class PatchEmbed(nn.Module):