From 93cc08fdc5a3f6716c183150b8370621788a13f0 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Sat, 20 Nov 2021 15:50:51 -0800 Subject: [PATCH] Make evonorm variables 1d to match other PyTorch norm layers, will break weight compat for any existing use (likely minimal, easy to fix). --- timm/models/layers/evo_norm.py | 37 ++++++++++++++++------------------ 1 file changed, 17 insertions(+), 20 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 50367f9b..8c08e49f 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -21,12 +21,10 @@ class EvoNormBatch2d(nn.Module): self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - param_shape = (1, num_features, 1, 1) - self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if apply_act: - self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.register_buffer('running_var', torch.ones(1, num_features, 1, 1)) + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): @@ -38,20 +36,21 @@ class EvoNormBatch2d(nn.Module): def forward(self, x): assert x.dim() == 4, 'expected 4D input' x_type = x.dtype + running_var = self.running_var.view(1, -1, 1, 1) if self.training: var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) n = x.numel() / x.shape[1] - self.running_var.copy_( - var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) + running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum) + self.running_var.copy_(running_var.view(self.running_var.shape)) else: - var = self.running_var + var = running_var - if self.apply_act: - v = self.v.to(dtype=x_type) + if self.v is not None: + v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = d.max((var + self.eps).sqrt().to(dtype=x_type)) x = x / d - return x * self.weight + self.bias + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) class EvoNormSample2d(nn.Module): @@ -60,11 +59,9 @@ class EvoNormSample2d(nn.Module): self.apply_act = apply_act # apply activation (non-linearity) self.groups = groups self.eps = eps - param_shape = (1, num_features, 1, 1) - self.weight = nn.Parameter(torch.ones(param_shape), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(param_shape), requires_grad=True) - if apply_act: - self.v = nn.Parameter(torch.ones(param_shape), requires_grad=True) + self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) + self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) + self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None self.reset_parameters() def reset_parameters(self): @@ -77,9 +74,9 @@ class EvoNormSample2d(nn.Module): _assert(x.dim() == 4, 'expected 4D input') B, C, H, W = x.shape _assert(C % self.groups == 0, '') - if self.apply_act: - n = x * (x * self.v).sigmoid() + if self.v is not None: + n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() x = x.reshape(B, self.groups, -1) x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = x.reshape(B, C, H, W) - return x * self.weight + self.bias + return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)