Make evonorm variables 1d to match other PyTorch norm layers, will break weight compat for any existing use (likely minimal, easy to fix).

pull/989/head
Ross Wightman 3 years ago
parent af607b75cc
commit 93cc08fdc5

@ -21,12 +21,10 @@ class EvoNormBatch2d(nn.Module):
self.apply_act = apply_act # apply activation (non-linearity) self.apply_act = apply_act # apply activation (non-linearity)
self.momentum = momentum self.momentum = momentum
self.eps = eps self.eps = eps
param_shape = (1, num_features, 1, 1) self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
if apply_act: self.register_buffer('running_var', torch.ones(num_features))
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.register_buffer('running_var', torch.ones(1, num_features, 1, 1))
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
@ -38,20 +36,21 @@ class EvoNormBatch2d(nn.Module):
def forward(self, x): def forward(self, x):
assert x.dim() == 4, 'expected 4D input' assert x.dim() == 4, 'expected 4D input'
x_type = x.dtype x_type = x.dtype
running_var = self.running_var.view(1, -1, 1, 1)
if self.training: if self.training:
var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True)
n = x.numel() / x.shape[1] n = x.numel() / x.shape[1]
self.running_var.copy_( running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum)
var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) self.running_var.copy_(running_var.view(self.running_var.shape))
else: else:
var = self.running_var var = running_var
if self.apply_act: if self.v is not None:
v = self.v.to(dtype=x_type) v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1)
d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type)
d = d.max((var + self.eps).sqrt().to(dtype=x_type)) d = d.max((var + self.eps).sqrt().to(dtype=x_type))
x = x / d x = x / d
return x * self.weight + self.bias return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
class EvoNormSample2d(nn.Module): class EvoNormSample2d(nn.Module):
@ -60,11 +59,9 @@ class EvoNormSample2d(nn.Module):
self.apply_act = apply_act # apply activation (non-linearity) self.apply_act = apply_act # apply activation (non-linearity)
self.groups = groups self.groups = groups
self.eps = eps self.eps = eps
param_shape = (1, num_features, 1, 1) self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True)
self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True)
self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None
if apply_act:
self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True)
self.reset_parameters() self.reset_parameters()
def reset_parameters(self): def reset_parameters(self):
@ -77,9 +74,9 @@ class EvoNormSample2d(nn.Module):
_assert(x.dim() == 4, 'expected 4D input') _assert(x.dim() == 4, 'expected 4D input')
B, C, H, W = x.shape B, C, H, W = x.shape
_assert(C % self.groups == 0, '') _assert(C % self.groups == 0, '')
if self.apply_act: if self.v is not None:
n = x * (x * self.v).sigmoid() n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid()
x = x.reshape(B, self.groups, -1) x = x.reshape(B, self.groups, -1)
x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt()
x = x.reshape(B, C, H, W) x = x.reshape(B, C, H, W)
return x * self.weight + self.bias return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)

Loading…
Cancel
Save