diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 119d9774..6aae2130 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -139,12 +139,12 @@ class Block(nn.Module): self.dropout = nn.Dropout(p=drop) def forward(self, x): - residual = x.clone() + identity = x x = self.norm1(x) x = self.attn(x) x = self.dropout(x) x = self.drop_path(x) - x = x + residual + x = x + identity y = self.norm2(x) y = self.mlp(y) y = self.drop_path(y)