""" Normalization layers and wrappers
"""
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupNorm(nn.GroupNorm):
    """Group normalization with the channel count as the first argument.

    ``nn.GroupNorm`` takes ``(num_groups, num_channels)``; this wrapper
    accepts ``(num_channels, num_groups)`` instead so it can be dropped into
    factories that pass the channel count first (e.g. when swapping norm
    layers with BatchNorm).
    """

    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
        # Re-order the swapped arguments back to the parent's expected order.
        super().__init__(num_groups, num_channels, eps=eps, affine=affine)

    def forward(self, x):
        # Explicit functional call; equivalent to nn.GroupNorm.forward.
        return F.group_norm(
            x, self.num_groups, self.weight, self.bias, self.eps)