From 3a55a30ed1db33a11e13400521d0643dec131e53 Mon Sep 17 00:00:00 2001
From: Alexander Soare
Date: Sun, 11 Jul 2021 14:25:58 +0100
Subject: [PATCH] add notes from author

---
 timm/models/xcit.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/timm/models/xcit.py b/timm/models/xcit.py
index be357ddb..886aa930 100644
--- a/timm/models/xcit.py
+++ b/timm/models/xcit.py
@@ -229,8 +229,7 @@ class ClassAttentionBlock(nn.Module):
         else:
             self.gamma1, self.gamma2 = 1.0, 1.0

-        # (note from official code)
-        # FIXME: A hack for models pre-trained with layernorm over all the tokens not just the CLS
+        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
         self.tokens_norm = tokens_norm

     def forward(self, x):
@@ -309,6 +308,7 @@ class XCABlock(nn.Module):
     def forward(self, x, H: int, W: int):
         x = x + self.drop_path(self.gamma1 * self.attn(self.norm1(x)))
         # NOTE official code has 3 then 2, so keeping it the same to be consistent with loaded weights
+        # See https://github.com/rwightman/pytorch-image-models/pull/747#issuecomment-877795721
         x = x + self.drop_path(self.gamma3 * self.local_mp(self.norm3(x), H, W))
         x = x + self.drop_path(self.gamma2 * self.mlp(self.norm2(x)))
         return x