diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b9857ed2..ad727759 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -251,7 +251,7 @@ class VisionTransformer(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} - def forward(self, x, attn_mask=None): + def forward_features(self, x, attn_mask=None): B = x.shape[0] x = self.patch_embed(x) @@ -263,6 +263,10 @@ class VisionTransformer(nn.Module): x = blk(x, attn_mask=attn_mask) x = self.norm(x[:, 0]) + return x + + def forward(self, x, attn_mask=None): + x = self.forward_features(x, attn_mask) x = self.head(x) return x