Fix FX breaking assert in evonorm

pull/1007/head
Ross Wightman 3 years ago
parent f83b0b01e3
commit 480c676ffa

@ -34,18 +34,17 @@ class EvoNormBatch2d(nn.Module):
nn.init.ones_(self.v)
def forward(self, x):
assert x.dim() == 4, 'expected 4D input'
_assert(x.dim() == 4, 'expected 4D input')
x_type = x.dtype
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
if self.v is not None:
running_var = self.running_var.view(1, -1, 1, 1)
if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1]
running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
self.running_var.copy_(running_var.view(self.running_var.shape))
else:
var = running_var
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type))

Loading…
Cancel
Save