vision transformer forward_features

pull/256/head
Zvi Lapp 5 years ago
parent ccfb5751ab
commit f973722adb

@@ -251,7 +251,7 @@ class VisionTransformer(nn.Module):
     def no_weight_decay(self):
         return {'pos_embed', 'cls_token'}
 
-    def forward(self, x, attn_mask=None):
+    def forward_features(self, x, attn_mask=None):
         B = x.shape[0]
         x = self.patch_embed(x)
 
@@ -263,6 +263,10 @@ class VisionTransformer(nn.Module):
             x = blk(x, attn_mask=attn_mask)
         x = self.norm(x[:, 0])
         return x
+    def forward(self, x, attn_mask=None):
+        x = self.forward_features(x, attn_mask)
+        x = self.head(x)
+        return x
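
For context, this commit splits the model's forward pass in two: forward_features() produces the normalized class-token embedding, and forward() applies the classification head on top of it. The snippet below is only an illustrative, self-contained sketch of that pattern, not the repository's actual VisionTransformer; the TinyViT class and its layer sizes are made up for the example.

import torch
import torch.nn as nn

class TinyViT(nn.Module):
    def __init__(self, token_dim=16, embed_dim=32, num_classes=10):
        super().__init__()
        self.patch_embed = nn.Linear(token_dim, embed_dim)  # stand-in for the real patch embedding
        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward_features(self, x, attn_mask=None):
        # Feature extraction only: embed tokens, normalize, return the first ("class") token.
        x = self.patch_embed(x)
        x = self.norm(x[:, 0])
        return x

    def forward(self, x, attn_mask=None):
        # Same structure as the commit: compute features first, then apply the classification head.
        x = self.forward_features(x, attn_mask)
        x = self.head(x)
        return x

model = TinyViT()
tokens = torch.randn(4, 5, 16)             # (batch, num_tokens, token_dim)
features = model.forward_features(tokens)  # pre-head embedding, shape (4, 32)
logits = model(tokens)                     # class logits, shape (4, 10)
print(features.shape, logits.shape)

The benefit of the split is that downstream code (feature extraction, linear probing, and the like) can call forward_features() directly and reuse the pre-head embedding without going through the classifier.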
