From f973722adba839ee0f77c5dcc036b1b186fdf5e8 Mon Sep 17 00:00:00 2001 From: Zvi Lapp Date: Sun, 18 Oct 2020 18:24:44 +0300 Subject: [PATCH] vision transform forward_features --- timm/models/vision_transformer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index b9857ed2..ad727759 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -251,7 +251,7 @@ class VisionTransformer(nn.Module): def no_weight_decay(self): return {'pos_embed', 'cls_token'} - def forward(self, x, attn_mask=None): + def forward_features(self, x, attn_mask=None): B = x.shape[0] x = self.patch_embed(x) @@ -263,6 +263,10 @@ class VisionTransformer(nn.Module): x = blk(x, attn_mask=attn_mask) x = self.norm(x[:, 0]) + return x + + def forward(self, x, attn_mask=None): + x = self.forward_features(x, attn_mask) x = self.head(x) return x