|
|
@ -248,7 +248,7 @@ class EvoNorm2dS0a(EvoNorm2dS0):
|
|
|
|
d = group_std(x, self.groups, self.eps)
|
|
|
|
d = group_std(x, self.groups, self.eps)
|
|
|
|
if self.v is not None:
|
|
|
|
if self.v is not None:
|
|
|
|
v = self.v.view(v_shape).to(dtype=x_dtype)
|
|
|
|
v = self.v.view(v_shape).to(dtype=x_dtype)
|
|
|
|
x = x * (x * v).sigmoid_()
|
|
|
|
x = x * (x * v).sigmoid()
|
|
|
|
x = x / d
|
|
|
|
x = x / d
|
|
|
|
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
|
|
|
|
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
|
|
|
|
|
|
|
|
|
|
|
|