You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
pytorch-image-models/timm/layers/norm_act.py

465 lines
17 KiB

""" Normalization + Activation Layers
Provides Norm+Act fns for standard PyTorch norm layers such as
* BatchNorm
* GroupNorm
* LayerNorm
This allows swapping with alternative layers that are natively both norm + act such as
* EvoNorm (evo_norm.py)
* FilterResponseNorm (filter_response_norm.py)
* InplaceABN (inplace_abn.py)
Hacked together by / Copyright 2022 Ross Wightman
"""
from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
from torch.nn import functional as F
from torchvision.ops.misc import FrozenBatchNorm2d
from .create_act import get_act_layer
from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
from .trace_utils import _assert
class BatchNormAct2d(nn.BatchNorm2d):
"""BatchNorm + Activation
This module performs BatchNorm + Activation in a manner that will remain backwards
compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member.
"""
def __init__(
self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
device=None,
dtype=None
):
try:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
**factory_kwargs
)
except TypeError:
# NOTE for backwards compat with old PyTorch w/o factory device/dtype support
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def forward(self, x):
# cut & paste of torch.nn.BatchNorm2d.forward impl to avoid issues with torchscript and tracing
_assert(x.ndim == 4, f'expected 4D input (got {x.ndim}D input)')
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
else:
exponential_average_factor = self.momentum
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked.add_(1) # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
x = F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
x = self.drop(x)
x = self.act(x)
return x
class SyncBatchNormAct(nn.SyncBatchNorm):
# Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
# This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
# but ONLY when used in conjunction with the timm conversion function below.
# Do not create this module directly or use the PyTorch conversion function.
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine
if hasattr(self, "drop"):
x = self.drop(x)
if hasattr(self, "act"):
x = self.act(x)
return x
def convert_sync_batchnorm(module, process_group=None):
# convert both BatchNorm and BatchNormAct layers to Synchronized variants
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
if isinstance(module, BatchNormAct2d):
# convert timm norm + act layer
module_output = SyncBatchNormAct(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group=process_group,
)
# set act and drop attr from the original module
module_output.act = module.act
module_output.drop = module.drop
else:
# convert standard BatchNorm layers
module_output = torch.nn.SyncBatchNorm(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group,
)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, convert_sync_batchnorm(child, process_group))
del module
return module_output
class FrozenBatchNormAct2d(torch.nn.Module):
"""
BatchNormAct2d where the batch statistics and the affine parameters are fixed
Args:
num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)``
eps (float): a value added to the denominator for numerical stability. Default: 1e-5
"""
def __init__(
self,
num_features: int,
eps: float = 1e-5,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super().__init__()
self.eps = eps
self.register_buffer("weight", torch.ones(num_features))
self.register_buffer("bias", torch.zeros(num_features))
self.register_buffer("running_mean", torch.zeros(num_features))
self.register_buffer("running_var", torch.ones(num_features))
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def _load_from_state_dict(
self,
state_dict: dict,
prefix: str,
local_metadata: dict,
strict: bool,
missing_keys: List[str],
unexpected_keys: List[str],
error_msgs: List[str],
):
num_batches_tracked_key = prefix + "num_batches_tracked"
if num_batches_tracked_key in state_dict:
del state_dict[num_batches_tracked_key]
super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# move reshapes to the beginning
# to make it fuser-friendly
w = self.weight.reshape(1, -1, 1, 1)
b = self.bias.reshape(1, -1, 1, 1)
rv = self.running_var.reshape(1, -1, 1, 1)
rm = self.running_mean.reshape(1, -1, 1, 1)
scale = w * (rv + self.eps).rsqrt()
bias = b - rm * scale
x = x * scale + bias
x = self.act(self.drop(x))
return x
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})"
def freeze_batch_norm_2d(module):
"""
Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers
of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)):
res = FrozenBatchNormAct2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
res = FrozenBatchNorm2d(module.num_features)
res.num_features = module.num_features
res.affine = module.affine
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = freeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def unfreeze_batch_norm_2d(module):
"""
Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance
of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
recursively and submodules are converted in place.
Args:
module (torch.nn.Module): Any PyTorch module.
Returns:
torch.nn.Module: Resulting module
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
"""
res = module
if isinstance(module, FrozenBatchNormAct2d):
res = BatchNormAct2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
res.drop = module.drop
res.act = module.act
elif isinstance(module, FrozenBatchNorm2d):
res = torch.nn.BatchNorm2d(module.num_features)
if module.affine:
res.weight.data = module.weight.data.clone().detach()
res.bias.data = module.bias.data.clone().detach()
res.running_mean.data = module.running_mean.data
res.running_var.data = module.running_var.data
res.eps = module.eps
else:
for name, child in module.named_children():
new_child = unfreeze_batch_norm_2d(child)
if new_child is not child:
res.add_module(name, new_child)
return res
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
return num_channels // group_size
return num_groups
class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(
self,
num_channels,
num_groups=32,
eps=1e-5,
affine=True,
group_size=None,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNormAct, self).__init__(
_num_groups(num_channels, num_groups, group_size),
num_channels,
eps=eps,
affine=affine,
)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class GroupNorm1Act(nn.GroupNorm):
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct(nn.LayerNorm):
def __init__(
self,
normalization_shape: Union[int, List[int], torch.Size],
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self,
num_channels,
eps=1e-5,
affine=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
self._fast_norm = is_fast_norm()
def forward(self, x):
x = x.permute(0, 2, 3, 1)
if self._fast_norm:
x = fast_layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
else:
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
x = self.drop(x)
x = self.act(x)
return x