diff --git a/timm/layers/__init__.py b/timm/layers/__init__.py index 625b4826..6b2dabba 100644 --- a/timm/layers/__init__.py +++ b/timm/layers/__init__.py @@ -29,7 +29,8 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp, GlobalResponseNormMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm +from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\ + SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed, resample_patch_embed from .pool2d_same import AvgPool2dSame, create_pool2d diff --git a/timm/layers/norm_act.py b/timm/layers/norm_act.py index ff075fbc..5ca21d18 100644 --- a/timm/layers/norm_act.py +++ b/timm/layers/norm_act.py @@ -17,6 +17,7 @@ from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F +from torchvision.ops.misc import FrozenBatchNorm2d from .create_act import get_act_layer from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm @@ -77,7 +78,7 @@ class BatchNormAct2d(nn.BatchNorm2d): if self.training and self.track_running_stats: # TODO: if statement only here to tell the jit to skip emitting this when it is None if self.num_batches_tracked is not None: # type: ignore[has-type] - self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type] + self.num_batches_tracked.add_(1) # type: ignore[has-type] if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float(self.num_batches_tracked) else: # use exponential moving average @@ -169,6 +170,159 @@ def convert_sync_batchnorm(module, process_group=None): return module_output +class FrozenBatchNormAct2d(torch.nn.Module): + """ + BatchNormAct2d where the batch statistics and the affine parameters are fixed + + Args: + num_features (int): Number of features ``C`` from an expected input of size ``(N, C, H, W)`` + eps (float): a value added to the denominator for numerical stability. Default: 1e-5 + """ + + def __init__( + self, + num_features: int, + eps: float = 1e-5, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): + super().__init__() + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features)) + + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + + def _load_from_state_dict( + self, + state_dict: dict, + prefix: str, + local_metadata: dict, + strict: bool, + missing_keys: List[str], + unexpected_keys: List[str], + error_msgs: List[str], + ): + num_batches_tracked_key = prefix + "num_batches_tracked" + if num_batches_tracked_key in state_dict: + del state_dict[num_batches_tracked_key] + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # move reshapes to the beginning + # to make it fuser-friendly + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + rv = self.running_var.reshape(1, -1, 1, 1) + rm = self.running_mean.reshape(1, -1, 1, 1) + scale = w * (rv + self.eps).rsqrt() + bias = b - rm * scale + x = x * scale + bias + x = self.act(self.drop(x)) + return x + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.weight.shape[0]}, eps={self.eps}, act={self.act})" + + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` or `BatchNormAct2d` and `SyncBatchNormAct2d` layers + of provided module into `FrozenBatchNorm2d` or `FrozenBatchNormAct2d` respectively. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, (BatchNormAct2d, SyncBatchNormAct)): + res = FrozenBatchNormAct2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + res.drop = module.drop + res.act = module.act + elif isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 + """ + res = module + if isinstance(module, FrozenBatchNormAct2d): + res = BatchNormAct2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + res.drop = module.drop + res.act = module.act + elif isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + def _num_groups(num_channels, num_groups, group_size): if group_size: assert num_channels % group_size == 0 @@ -179,10 +333,54 @@ def _num_groups(num_channels, num_groups, group_size): class GroupNormAct(nn.GroupNorm): # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args def __init__( - self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + num_channels, + num_groups=32, + eps=1e-5, + affine=True, + group_size=None, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(GroupNormAct, self).__init__( - _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine) + _num_groups(num_channels, num_groups, group_size), + num_channels, + eps=eps, + affine=affine, + ) + self.drop = drop_layer() if drop_layer is not None else nn.Identity() + act_layer = get_act_layer(act_layer) # string -> nn.Module + if act_layer is not None and apply_act: + act_args = dict(inplace=True) if inplace else {} + self.act = act_layer(**act_args) + else: + self.act = nn.Identity() + self._fast_norm = is_fast_norm() + + def forward(self, x): + if self._fast_norm: + x = fast_group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + else: + x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + x = self.drop(x) + x = self.act(x) + return x + + +class GroupNorm1Act(nn.GroupNorm): + def __init__( + self, + num_channels, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): + super(GroupNorm1Act, self).__init__(1, num_channels, eps=eps, affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: @@ -204,8 +402,15 @@ class GroupNormAct(nn.GroupNorm): class LayerNormAct(nn.LayerNorm): def __init__( - self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + normalization_shape: Union[int, List[int], torch.Size], + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module @@ -228,8 +433,15 @@ class LayerNormAct(nn.LayerNorm): class LayerNormAct2d(nn.LayerNorm): def __init__( - self, num_channels, eps=1e-5, affine=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): + self, + num_channels, + eps=1e-5, + affine=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + ): super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module diff --git a/timm/utils/model.py b/timm/utils/model.py index b95c4539..d74ee5b7 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -7,6 +7,8 @@ import fnmatch import torch from torchvision.ops.misc import FrozenBatchNorm2d +from timm.layers import BatchNormAct2d, SyncBatchNormAct, FrozenBatchNormAct2d,\ + freeze_batch_norm_2d, unfreeze_batch_norm_2d from .model_ema import ModelEma @@ -100,70 +102,6 @@ def extract_spp_stats( return hook.stats -def freeze_batch_norm_2d(module): - """ - Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is - itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and - returned. Otherwise, the module is walked recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): - res = FrozenBatchNorm2d(module.num_features) - res.num_features = module.num_features - res.affine = module.affine - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = freeze_batch_norm_2d(child) - if new_child is not child: - res.add_module(name, new_child) - return res - - -def unfreeze_batch_norm_2d(module): - """ - Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance - of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked - recursively and submodules are converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. - - Returns: - torch.nn.Module: Resulting module - - Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 - """ - res = module - if isinstance(module, FrozenBatchNorm2d): - res = torch.nn.BatchNorm2d(module.num_features) - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = unfreeze_batch_norm_2d(child) - if new_child is not child: - res.add_module(name, new_child) - return res - - def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): """ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is @@ -179,7 +117,12 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, """ assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' - if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if isinstance(root_module, ( + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.SyncBatchNorm, + BatchNormAct2d, + SyncBatchNormAct, + )): # Raise assertion here because we can't convert it in place raise AssertionError( "You have provided a batch norm layer as the `root module`. Please use " @@ -213,13 +156,18 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't # convert it in place, but will return the converted result. In this case `res` holds the converted # result and we may try to re-assign the named module - if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if isinstance(m, ( + torch.nn.modules.batchnorm.BatchNorm2d, + torch.nn.modules.batchnorm.SyncBatchNorm, + BatchNormAct2d, + SyncBatchNormAct, + )): _add_submodule(root_module, n, res) # Unfreeze batch norm else: res = unfreeze_batch_norm_2d(m) # Ditto. See note above in mode == 'freeze' branch - if isinstance(m, FrozenBatchNorm2d): + if isinstance(m, (FrozenBatchNorm2d, FrozenBatchNormAct2d)): _add_submodule(root_module, n, res)