From fc8b8afb6f0144afcbd927eb25546864e8a283b4 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Thu, 13 Aug 2020 18:23:50 -0700 Subject: [PATCH] Fix a silly bug in Sample version of EvoNorm missing x* part of swish, update EvoNormBatch to accumulated unbiased variance. --- timm/models/layers/evo_norm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 2ff692db..9023afd0 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -38,13 +38,15 @@ class EvoNormBatch2d(nn.Module): x_type = x.dtype if self.training: var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) - self.running_var.copy_(self.momentum * var.detach() + (1 - self.momentum) * self.running_var) + n = x.numel() / x.shape[1] + self.running_var.copy_( + var.detach() * self.momentum * (n / (n - 1)) + self.running_var * (1 - self.momentum)) else: var = self.running_var if self.apply_act: v = self.v.to(dtype=x_type) - d = (x * v) + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) + d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) d = d.max((var + self.eps).sqrt().to(dtype=x_type)) x = x / d return x * self.weight + self.bias @@ -74,8 +76,8 @@ class EvoNormSample2d(nn.Module): B, C, H, W = x.shape assert C % self.groups == 0 if self.apply_act: - n = (x * self.v).sigmoid().reshape(B, self.groups, -1) + n = x * (x * self.v).sigmoid() x = x.reshape(B, self.groups, -1) - x = n / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() + x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() x = x.reshape(B, C, H, W) return x * self.weight + self.bias