diff --git a/timm/models/edgenext.py b/timm/models/edgenext.py index a81bd9b0..14a1a525 100644 --- a/timm/models/edgenext.py +++ b/timm/models/edgenext.py @@ -152,13 +152,18 @@ class CrossCovarianceAttn(nn.Module): self.proj = nn.Linear(dim, dim) self.proj_drop = nn.Dropout(proj_drop) + def normalize(self, x, dim=-1, eps=1e-12): + norm = torch.linalg.norm(x, dim=dim, keepdim=True) + norm[norm < eps] = eps + return torch.div(x, norm) + def forward(self, x): B, N, C = x.shape - qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1) + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1) q, k, v = qkv.unbind(0) # NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map - attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature + attn = (self.normalize(q, dim=-1) @ self.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature attn = attn.softmax(dim=-1) attn = self.attn_drop(attn)