|
|
|
@ -206,7 +206,7 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
"""
|
|
|
|
|
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
|
|
|
|
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
|
|
|
|
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
|
|
|
|
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pre_logits=True):
|
|
|
|
|
super().__init__()
|
|
|
|
|
self.num_classes = num_classes
|
|
|
|
|
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
|
|
|
@ -232,8 +232,8 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
self.norm = norm_layer(embed_dim)
|
|
|
|
|
|
|
|
|
|
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
|
|
|
|
|
#self.repr = nn.Linear(embed_dim, representation_size)
|
|
|
|
|
#self.repr_act = nn.Tanh()
|
|
|
|
|
self.repr = nn.Linear(embed_dim, representation_size) if pre_logits else nn.Identity()
|
|
|
|
|
self.repr_act = nn.Tanh() if pre_logits else nn.Identity()
|
|
|
|
|
|
|
|
|
|
# Classifier head
|
|
|
|
|
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
|
|
|
@ -279,6 +279,8 @@ class VisionTransformer(nn.Module):
|
|
|
|
|
|
|
|
|
|
def forward(self, x):
|
|
|
|
|
x = self.forward_features(x)
|
|
|
|
|
x = self.repr(x)
|
|
|
|
|
x = self.repr_act(x)
|
|
|
|
|
x = self.head(x)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|