diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 72f3a61a..72974e29 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -206,7 +206,7 @@ class VisionTransformer(nn.Module): """ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., - drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm): + drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm, pre_logits=True): super().__init__() self.num_classes = num_classes self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models @@ -232,8 +232,8 @@ class VisionTransformer(nn.Module): self.norm = norm_layer(embed_dim) # NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here - #self.repr = nn.Linear(embed_dim, representation_size) - #self.repr_act = nn.Tanh() + self.repr = nn.Linear(embed_dim, representation_size) if pre_logits else nn.Identity() + self.repr_act = nn.Tanh() if pre_logits else nn.Identity() # Classifier head self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -279,6 +279,8 @@ class VisionTransformer(nn.Module): def forward(self, x): x = self.forward_features(x) + x = self.repr(x) + x = self.repr_act(x) x = self.head(x) return x