Fix FX breaking assert in evonorm

pull/1007/head
Ross Wightman 3 years ago
parent f83b0b01e3
commit 480c676ffa

@ -34,8 +34,9 @@ class EvoNormBatch2d(nn.Module):
nn.init.ones_(self.v) nn.init.ones_(self.v)
def forward(self, x): def forward(self, x):
assert x.dim() == 4, 'expected 4D input' _assert(x.dim() == 4, 'expected 4D input')
x_type = x.dtype x_type = x.dtype
if self.v is not None:
running_var = self.running_var.view(1, -1, 1, 1) running_var = self.running_var.view(1, -1, 1, 1)
if self.training: if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
@ -44,8 +45,6 @@ class EvoNormBatch2d(nn.Module):
self.running_var.copy_(running_var.view(self.running_var.shape)) self.running_var.copy_(running_var.view(self.running_var.shape))
else: else:
var = running_var var = running_var
if self.v is not None:
v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type)) d = d.max((var + self.eps).sqrt().to(dtype=x_type))

Loading…
Cancel
Save