@@ -34,18 +34,17 @@ class EvoNormBatch2d(nn.Module):
             nn.init.ones_(self.v)
 
     def forward(self, x):
-        assert x.dim() == 4, 'expected 4D input'
+        _assert(x.dim() == 4, 'expected 4D input')
         x_type = x.dtype
+        running_var = self.running_var.view(1, -1, 1, 1)
+        if self.training:
+            var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
+            n = x.numel() / x.shape[1]
+            running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
+            self.running_var.copy_(running_var.view(self.running_var.shape))
+        else:
+            var = running_var
         if self.v is not None:
-            running_var = self.running_var.view(1, -1, 1, 1)
-            if self.training:
-                var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
-                n = x.numel() / x.shape[1]
-                running_var = var.detach() * self.momentum * (n / (n - 1)) + \
-                    running_var * (1 - self.momentum)
-                self.running_var.copy_(running_var.view(self.running_var.shape))
-            else:
-                var = running_var
             v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
             d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
             d = d.max((var + self.eps).sqrt().to(dtype=x_type))
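
For reference, a minimal sketch (not part of the patch) of how the hoisted running-stat update could be exercised. It assumes EvoNormBatch2d is importable from timm.models.layers with its usual num_features/apply_act arguments; the import path and signature may differ between timm versions.

# Hypothetical check, assuming EvoNormBatch2d is exported from timm.models.layers.
import torch
from timm.models.layers import EvoNormBatch2d

norm = EvoNormBatch2d(num_features=8, apply_act=False)  # apply_act=False -> self.v is None
norm.train()

before = norm.running_var.clone()
_ = norm(torch.randn(4, 8, 16, 16))

# With the update moved out of the `if self.v is not None:` branch, running_var
# is refreshed during training even when no activation parameter v exists.
print(torch.equal(before, norm.running_var))  # expected: False after this change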