|
|
|
@ -38,13 +38,15 @@ class EvoNormBatch2d(nn.Module):
|
|
|
|
|
x_type = x.dtype
|
|
|
|
|
if self.training:
|
|
|
|
|
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
|
|
|
|
|
self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var)
|
|
|
|
|
n = x.numel() / x.shape[1]
|
|
|
|
|
self.running_var.copy_(
|
|
|
|
|
var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum))
|
|
|
|
|
else:
|
|
|
|
|
var = self.running_var
|
|
|
|
|
|
|
|
|
|
if self.apply_act:
|
|
|
|
|
v = self.v.to(dtype=x_type)
|
|
|
|
|
d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
|
|
|
|
|
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
|
|
|
|
|
d = d.max((var + self.eps).sqrt().to(dtype=x_type))
|
|
|
|
|
x = x / d
|
|
|
|
|
return x * self.weight + self.bias
|
|
|
|
@ -74,8 +76,8 @@ class EvoNormSample2d(nn.Module):
|
|
|
|
|
B, C, H, W = x.shape
|
|
|
|
|
assert C % self.groups == 0
|
|
|
|
|
if self.apply_act:
|
|
|
|
|
n = (x * self.v).sigmoid().reshape(B, self.groups, -1)
|
|
|
|
|
n = x * (x * self.v).sigmoid()
|
|
|
|
|
x = x.reshape(B, self.groups, -1)
|
|
|
|
|
x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
|
|
|
|
|
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
|
|
|
|
|
x = x.reshape(B, C, H, W)
|
|
|
|
|
return x * self.weight + self.bias
|
|
|
|
|