Make EdgeNeXt onnx-exportable.

pull/1385/head
kakukakujirori 3 years ago
parent 7430a85d07
commit ed1057f1e7

@ -152,13 +152,18 @@ class CrossCovarianceAttn(nn.Module):
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def normalize(self, x, dim=-1, eps=1e-12):
norm = torch.linalg.norm(x, dim=dim, keepdim=True)
norm[norm < eps] = eps
return torch.div(x, norm)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 4, 1)
q, k, v = qkv.unbind(0)
# NOTE, this is NOT spatial attn, q, k, v are B, num_heads, C, L --> C x C attn map
attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
attn = (self.normalize(q, dim=-1) @ self.normalize(k, dim=-1).transpose(-2, -1)) * self.temperature
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

Loading…
Cancel
Save