You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
189 lines
7.8 KiB
189 lines
7.8 KiB
""" Normalization + Activation Layers
|
|
"""
|
|
from typing import Union, List
|
|
|
|
import torch
|
|
from torch import nn as nn
|
|
from torch.nn import functional as F
|
|
|
|
from .create_act import get_act_layer
|
|
|
|
|
|
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.

    Forward computes ``act(drop(batch_norm(x)))``.

    Args:
        num_features: number of channels, forwarded to ``nn.BatchNorm2d``.
        eps: forwarded to ``nn.BatchNorm2d``.
        momentum: forwarded to ``nn.BatchNorm2d``; ``None`` selects a cumulative moving average
            of the running stats.
        affine: forwarded to ``nn.BatchNorm2d``.
        track_running_stats: forwarded to ``nn.BatchNorm2d``.
        apply_act: when False, ``self.act`` is ``nn.Identity`` regardless of ``act_layer``.
        act_layer: activation layer, resolved through ``get_act_layer`` (so a string name may
            also be accepted); ``None`` disables the activation.
        inplace: when True, the activation is constructed with ``inplace=True``.
        drop_layer: optional zero-argument callable returning a module applied between the
            norm and the activation; ``None`` uses ``nn.Identity``.
    """

    def __init__(
            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(BatchNormAct2d, self).__init__(
            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def _forward_jit(self, x):
        """A cut & paste of the contents of the PyTorch BatchNorm2d forward function.

        Used only when scripting (see ``forward``). It mirrors upstream, including the input
        check and the in-place ``num_batches_tracked`` update, so scripted and eager execution
        behave the same.
        """
        # Same input validation nn.BatchNorm2d.forward performs (it rejects non-4D input);
        # without it, the scripted path would silently accept inputs the eager path rejects.
        self._check_input_dim(x)

        # exponential_average_factor is set to self.momentum
        # (when it is available) only so that it gets updated
        # in ONNX graph when this node is exported to ONNX.
        if self.momentum is None:
            exponential_average_factor = 0.0
        else:
            exponential_average_factor = self.momentum

        if self.training and self.track_running_stats:
            # TODO: if statement only here to tell the jit to skip emitting this when it is None
            if self.num_batches_tracked is not None:  # type: ignore[has-type]
                # Increment in place (as upstream does) rather than rebinding the registered
                # buffer to a freshly allocated tensor.
                self.num_batches_tracked.add_(1)  # type: ignore[has-type]
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # Decide whether the mini-batch stats should be used for normalization rather than the buffers.
        # Mini-batch stats are used in training mode, and in eval mode when buffers are None.
        if self.training:
            bn_training = True
        else:
            bn_training = (self.running_mean is None) and (self.running_var is None)

        # Buffers are only updated if they are to be tracked and we are in training mode. Thus they only
        # need to be passed when the update should occur (i.e. in training mode when they are tracked), or
        # when buffer stats are used for normalization (i.e. in eval mode when buffers are not None).
        return F.batch_norm(
            x,
            # If buffers are not to be tracked, ensure that they won't be updated
            self.running_mean if not self.training or self.track_running_stats else None,
            self.running_var if not self.training or self.track_running_stats else None,
            self.weight,
            self.bias,
            bn_training,
            exponential_average_factor,
            self.eps,
        )

    @torch.jit.ignore
    def _forward_python(self, x):
        """Eager path: defer to ``nn.BatchNorm2d.forward`` unchanged."""
        return super(BatchNormAct2d, self).forward(x)

    def forward(self, x):
        """Apply batch norm, then ``self.drop``, then ``self.act``."""
        # FIXME cannot call parent forward() and maintain jit.script compatibility?
        if torch.jit.is_scripting():
            x = self._forward_jit(x)
        else:
            x = self._forward_python(x)
        x = self.drop(x)
        x = self.act(x)
        return x
|
|
|
|
|
|
def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
    """Group normalization built from reshape/mean ops instead of ``F.group_norm``.

    This is a workaround for some odd behaviour running on PyTorch XLA w/ TPUs.

    Args:
        x: input laid out as ``(N, C, *)``; ``C`` must be divisible by ``groups``.
        w: per-channel scale (``C`` elements), or ``None`` to skip scaling.
        b: per-channel shift (``C`` elements), or ``None`` to skip shifting.
        groups: number of channel groups.
        eps: value added to the variance before the square root.
        diff_sqm: compute the variance as ``E[x^2] - E[x]^2`` (clamped at 0) instead of
            ``E[(x - E[x])^2]``.
        flatten: reduce each group over one flattened dim instead of several trailing dims.

    Returns:
        Tensor with the same shape as ``x``.

    Raises:
        ValueError: if ``x`` has fewer than 2 dims, ``groups < 1``, or ``C`` is not divisible
            by ``groups``.
    """
    x_shape = x.shape
    # Without this check, flatten=True can silently reshape into groups that straddle
    # channels (whenever C*H*W, but not C, is divisible by groups).
    if x.ndim < 2 or groups < 1 or x_shape[1] % groups != 0:
        raise ValueError(
            f'group_norm_tpu expects input (N, C, *) with C divisible by groups, '
            f'got shape {list(x_shape)} and groups={groups}')

    if flatten:
        # One reduction dim holding every element of a group.
        norm_shape = (x_shape[0], groups, -1)
        reduce_dim = -1
    else:
        # Split C into (groups, C // groups) and reduce over that and all trailing dims.
        norm_shape = (x_shape[0], groups, x_shape[1] // groups) + x_shape[2:]
        # x.ndim is the pre-reshape rank; the reshaped tensor has rank x.ndim + 1.
        reduce_dim = tuple(range(2, x.ndim + 1))
    # Broadcast per-channel w / b over (N, C, *).
    affine_shape = (1, -1) + (1,) * (x.ndim - 2)

    x = x.reshape(norm_shape)
    # NOTE: when testing w/ AMP, x was cast to float32 here and back to the input dtype at the end.
    xm = x.mean(dim=reduce_dim, keepdim=True)
    if diff_sqm:
        # difference of squared mean and mean squared, faster on TPU
        var = (x.square().mean(dim=reduce_dim, keepdim=True) - xm.square()).clamp(0)
    else:
        var = (x - xm).square().mean(dim=reduce_dim, keepdim=True)
    x = (x - xm.expand(norm_shape)) / var.add(eps).sqrt().expand(norm_shape)
    x = x.reshape(x_shape)
    # Like F.group_norm, a None weight / bias means "no affine" (e.g. GroupNorm(affine=False)).
    if w is not None:
        x = x * w.view(affine_shape)
    if b is not None:
        x = x + b.view(affine_shape)
    return x
|
|
|
|
|
|
def _num_groups(num_channels, num_groups, group_size):
    """Resolve the number of groups for a GroupNorm layer.

    Args:
        num_channels: total number of channels.
        num_groups: group count used when ``group_size`` is falsy (``None`` or ``0``).
        group_size: channels per group; when set it takes precedence over ``num_groups``.

    Returns:
        ``num_channels // group_size`` if ``group_size`` is set, otherwise ``num_groups``.

    Raises:
        ValueError: if ``group_size`` is set and does not evenly divide ``num_channels``.
    """
    if not group_size:
        return num_groups
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if num_channels % group_size != 0:
        raise ValueError(
            f'num_channels ({num_channels}) must be divisible by group_size ({group_size})')
    return num_channels // group_size
|
|
|
|
|
|
class GroupNormAct(nn.GroupNorm):
    """GroupNorm followed by an optional drop layer and an activation.

    Subclasses ``nn.GroupNorm`` directly, so its parameters are the usual GroupNorm
    ``weight`` / ``bias``. Forward computes ``act(drop(group_norm(x)))``.

    Args:
        num_channels: number of input channels.
        num_groups: number of groups, used when ``group_size`` is not given.
        eps: forwarded to ``nn.GroupNorm``.
        affine: forwarded to ``nn.GroupNorm``.
        group_size: channels per group; when set it overrides ``num_groups``.
        apply_act: when False, no activation is applied.
        act_layer: activation layer resolved through ``get_act_layer``; ``None`` disables it.
        inplace: construct the activation with ``inplace=True``.
        drop_layer: optional zero-argument callable building a module applied before the activation.
    """

    # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
    def __init__(
            self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        groups = _num_groups(num_channels, num_groups, group_size)
        super(GroupNormAct, self).__init__(groups, num_channels, eps=eps, affine=affine)

        self.drop = nn.Identity() if drop_layer is None else drop_layer()

        resolved_act = get_act_layer(act_layer)  # string -> nn.Module
        if not (apply_act and resolved_act is not None):
            self.act = nn.Identity()
        else:
            extra_kwargs = {'inplace': True} if inplace else {}
            self.act = resolved_act(**extra_kwargs)

    def forward(self, x):
        # FIXME TPU temporary while resolving some performance issues: the group_norm_tpu
        # branch is disabled via a literal ``False`` (kept literal so TorchScript skips it).
        if False:
            x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps)
        else:
            x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
        return self.act(self.drop(x))
|
|
|
|
|
|
class LayerNormAct(nn.LayerNorm):
    """LayerNorm followed by an optional drop layer and an activation.

    Normalizes over the trailing ``normalized_shape`` dims (as ``nn.LayerNorm`` does), then
    computes ``act(drop(...))``.

    Args:
        normalization_shape: forwarded to ``nn.LayerNorm`` as its normalized shape.
        eps: forwarded to ``nn.LayerNorm``.
        affine: forwarded to ``nn.LayerNorm`` as ``elementwise_affine``.
        apply_act: when False, no activation is applied.
        act_layer: activation layer resolved through ``get_act_layer``; ``None`` disables it.
        inplace: construct the activation with ``inplace=True``.
        drop_layer: optional zero-argument callable building a module applied before the activation.
    """

    def __init__(
            self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)

        self.drop = nn.Identity() if drop_layer is None else drop_layer()

        resolved_act = get_act_layer(act_layer)  # string -> nn.Module
        if not (apply_act and resolved_act is not None):
            self.act = nn.Identity()
        else:
            extra_kwargs = {'inplace': True} if inplace else {}
            self.act = resolved_act(**extra_kwargs)

    def forward(self, x):
        normed = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return self.act(self.drop(normed))
|
|
|
|
|
|
class LayerNormAct2d(nn.LayerNorm):
    """LayerNorm over the channel dim of a channels-first tensor, then drop + activation.

    The channel dim (dim 1) is moved last, normalized with ``F.layer_norm`` over
    ``num_channels``, and moved back to dim 1; then ``self.drop`` and ``self.act`` are applied.
    Any input laid out as ``(N, C, *)`` is accepted (e.g. ``(N, C, H, W)``, ``(N, C, L)``).

    Args:
        num_channels: size of the channel dim, forwarded to ``nn.LayerNorm``.
        eps: forwarded to ``nn.LayerNorm``.
        affine: forwarded to ``nn.LayerNorm`` as ``elementwise_affine``.
        apply_act: when False, ``self.act`` is ``nn.Identity`` regardless of ``act_layer``.
        act_layer: activation layer resolved through ``get_act_layer``; ``None`` disables it.
        inplace: when True, the activation is constructed with ``inplace=True``.
        drop_layer: optional zero-argument callable returning a module applied between the
            norm and the activation; ``None`` uses ``nn.Identity``.
    """

    def __init__(
            self, num_channels, eps=1e-5, affine=True,
            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
        super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
        act_layer = get_act_layer(act_layer)  # string -> nn.Module
        if act_layer is not None and apply_act:
            act_args = dict(inplace=True) if inplace else {}
            self.act = act_layer(**act_args)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        # For 4D input, movedim(1, -1) / movedim(-1, 1) are exactly permute(0, 2, 3, 1) /
        # permute(0, 3, 1, 2); unlike hard-coded permutes they also work for any (N, C, *) rank.
        x = F.layer_norm(
            x.movedim(1, -1), self.normalized_shape, self.weight, self.bias, self.eps).movedim(-1, 1)
        x = self.drop(x)
        x = self.act(x)
        return x
|