@ -92,7 +92,7 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert ( C % groups == 0 , ' ' )
x_dtype = x . dtype
x = x . reshape ( B , groups , C / / groups , H , W )
rms = x . float ( ) . square ( ) . mean ( dim = ( 2 , 3 , 4 ) , keepdim = True ) . add ( eps ) . sqrt_ ( ) . to ( dtype= x_dtype)
rms = x . float ( ) . square ( ) . mean ( dim = ( 2 , 3 , 4 ) , keepdim = True ) . add ( eps ) . sqrt_ ( ) . to ( x_dtype)
return rms . expand ( x . shape ) . reshape ( B , C , H , W )
@ -160,14 +160,14 @@ class EvoNorm2dB1(nn.Module):
n = x . numel ( ) / x . shape [ 1 ]
self . running_var . copy_ (
self . running_var * ( 1 - self . momentum ) +
var . detach ( ) . to ( dtype = self . running_var . dtype ) * self . momentum * ( n / ( n - 1 ) ) )
var . detach ( ) . to ( self . running_var . dtype ) * self . momentum * ( n / ( n - 1 ) ) )
else :
var = self . running_var
var = var . to ( dtype= x_dtype) . view ( v_shape )
var = var . to ( x_dtype) . view ( v_shape )
left = var . add ( self . eps ) . sqrt_ ( )
right = ( x + 1 ) * instance_rms ( x , self . eps )
x = x / left . max ( right )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dB2 ( nn . Module ) :
@ -195,14 +195,14 @@ class EvoNorm2dB2(nn.Module):
n = x . numel ( ) / x . shape [ 1 ]
self . running_var . copy_ (
self . running_var * ( 1 - self . momentum ) +
var . detach ( ) . to ( dtype = self . running_var . dtype ) * self . momentum * ( n / ( n - 1 ) ) )
var . detach ( ) . to ( self . running_var . dtype ) * self . momentum * ( n / ( n - 1 ) ) )
else :
var = self . running_var
var = var . to ( dtype= x_dtype) . view ( v_shape )
var = var . to ( x_dtype) . view ( v_shape )
left = var . add ( self . eps ) . sqrt_ ( )
right = instance_rms ( x , self . eps ) - x
x = x / left . max ( right )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS0 ( nn . Module ) :
@ -231,9 +231,9 @@ class EvoNorm2dS0(nn.Module):
x_dtype = x . dtype
v_shape = ( 1 , - 1 , 1 , 1 )
if self . v is not None :
v = self . v . view ( v_shape ) . to ( dtype= x_dtype)
v = self . v . view ( v_shape ) . to ( x_dtype)
x = x * ( x * v ) . sigmoid ( ) / group_std ( x , self . groups , self . eps )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS0a ( EvoNorm2dS0 ) :
@ -247,10 +247,10 @@ class EvoNorm2dS0a(EvoNorm2dS0):
v_shape = ( 1 , - 1 , 1 , 1 )
d = group_std ( x , self . groups , self . eps )
if self . v is not None :
v = self . v . view ( v_shape ) . to ( dtype= x_dtype)
v = self . v . view ( v_shape ) . to ( x_dtype)
x = x * ( x * v ) . sigmoid ( )
x = x / d
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS1 ( nn . Module ) :
@ -284,7 +284,7 @@ class EvoNorm2dS1(nn.Module):
v_shape = ( 1 , - 1 , 1 , 1 )
if self . apply_act :
x = self . act ( x ) / group_std ( x , self . groups , self . eps )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS1a ( EvoNorm2dS1 ) :
@ -299,7 +299,7 @@ class EvoNorm2dS1a(EvoNorm2dS1):
x_dtype = x . dtype
v_shape = ( 1 , - 1 , 1 , 1 )
x = self . act ( x ) / group_std ( x , self . groups , self . eps )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS2 ( nn . Module ) :
@ -332,7 +332,7 @@ class EvoNorm2dS2(nn.Module):
v_shape = ( 1 , - 1 , 1 , 1 )
if self . apply_act :
x = self . act ( x ) / group_rms ( x , self . groups , self . eps )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )
class EvoNorm2dS2a ( EvoNorm2dS2 ) :
@ -347,4 +347,4 @@ class EvoNorm2dS2a(EvoNorm2dS2):
x_dtype = x . dtype
v_shape = ( 1 , - 1 , 1 , 1 )
x = self . act ( x ) / group_rms ( x , self . groups , self . eps )
return x * self . weight . view ( v_shape ) . to ( dtype= x_dtype) + self . bias . view ( v_shape ) . to ( dtype = x_dtype )
return x * self . weight . view ( v_shape ) . to ( x_dtype) + self . bias . view ( v_shape ) . to ( x_dtype )