diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
new file mode 100644
index 00000000..2925e5c7
--- /dev/null
+++ b/timm/models/layers/norm.py
@@ -0,0 +1,14 @@
+""" Normalization layers and wrappers
+"""
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GroupNorm(nn.GroupNorm):
+    def __init__(self, num_channels, num_groups, eps=1e-5, affine=True):
+        # NOTE num_channels is swapped to first arg for consistency in swapping norm layers with BN
+        super().__init__(num_groups, num_channels, eps=eps, affine=affine)
+
+    def forward(self, x):
+        return F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
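
For context, a minimal usage sketch of why the argument swap matters (not part of this diff): because `num_channels` is the first positional argument, this `GroupNorm` can be passed anywhere model code expects a BatchNorm-style `norm_layer(num_features)` constructor. The `ConvBlock` module and the `num_groups=8` value below are hypothetical, chosen only for illustration; the import path assumes this new file is in place.

```python
from functools import partial

import torch
import torch.nn as nn

from timm.models.layers.norm import GroupNorm  # the wrapper added in this diff

# Hypothetical block that builds its norm layer with a channels-first,
# BatchNorm-style constructor signature.
class ConvBlock(nn.Module):
    def __init__(self, in_chs, out_chs, norm_layer=nn.BatchNorm2d):
        super().__init__()
        self.conv = nn.Conv2d(in_chs, out_chs, 3, padding=1, bias=False)
        self.norm = norm_layer(out_chs)  # works for BN and the wrapped GN alike
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.act(self.norm(self.conv(x)))

# Swap BN for GroupNorm without touching ConvBlock; num_groups is bound via partial
# (8 must divide out_chs=32, per nn.GroupNorm's divisibility requirement).
block = ConvBlock(16, 32, norm_layer=partial(GroupNorm, num_groups=8))
out = block(torch.randn(2, 16, 8, 8))  # -> torch.Size([2, 32, 8, 8])
```

Without the swap, plain `nn.GroupNorm(num_groups, num_channels)` would put the group count first, so it could not be used as a drop-in `norm_layer` callable in code written against the BatchNorm convention.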