Include pre_logits in vit

pull/352/head
Zhiyuan Chen 5 years ago committed by GitHub
parent f8463b8fa9
commit 201b0046f7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@@ -206,7 +206,7 @@ class VisionTransformer(nn.Module):
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pre_logits=True):
super().__init__()
self.num_classes = num_classes
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
@@ -232,8 +232,8 @@ class VisionTransformer(nn.Module):
self.norm = norm_layer(embed_dim)
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
#self.repr = nn.Linear(embed_dim, representation_size)
#self.repr_act = nn.Tanh()
self.repr = nn.Linear(embed_dim, representation_size) if pre_logits else nn.Identity()
self.repr_act = nn.Tanh() if pre_logits else nn.Identity()
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -279,6 +279,8 @@ class VisionTransformer(nn.Module):
def forward(self, x):
    """Full forward pass of the vision transformer.

    Runs the patch/transformer feature extractor, then the (possibly
    identity) pre-logits projection and its activation, and finally the
    classifier head.

    Args:
        x: input batch (presumably image tensors of shape
           (B, C, H, W) — confirm against ``forward_features``).

    Returns:
        Classifier head output (logits when ``num_classes > 0``).
    """
    features = self.forward_features(x)
    # Pre-logits representation layer + tanh; both are nn.Identity
    # when the model was built without a pre-logits head.
    pre_logits = self.repr_act(self.repr(features))
    return self.head(pre_logits)

Loading…
Cancel
Save