diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py
index aace107b..fc500807 100644
--- a/timm/models/layers/norm.py
+++ b/timm/models/layers/norm.py
@@ -3,6 +3,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+import torchvision
 
 
 class GroupNorm(nn.GroupNorm):
@@ -22,3 +23,42 @@ class LayerNorm2d(nn.LayerNorm):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(
             x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
+
+
+class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+
+    Inherits from torchvision while adding the `convert_frozen_batchnorm` method from
+    https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py
+    """
+
+    @classmethod
+    def convert_frozen_batchnorm(cls, module):
+        """
+        Converts all BatchNorm layers of the provided module to FrozenBatchNorm. If `module` is itself a BatchNorm
+        layer, it is converted and returned. Otherwise, the module is walked recursively and BatchNorm layers are
+        converted in place.
+
+        Args:
+            module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant itself.
+
+        Returns:
+            torch.nn.Module: Resulting module
+        """
+        res = module
+        if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
+            res = cls(module.num_features)
+            if module.affine:
+                res.weight.data = module.weight.data.clone().detach()
+                res.bias.data = module.bias.data.clone().detach()
+            res.running_mean.data = module.running_mean.data
+            res.running_var.data = module.running_var.data
+            res.eps = module.eps
+        else:
+            for name, child in module.named_children():
+                new_child = cls.convert_frozen_batchnorm(child)
+                if new_child is not child:
+                    res.add_module(name, new_child)
+        return res
+
diff --git a/timm/utils/model.py b/timm/utils/model.py
index bd46e2f4..d0fb69ed 100644
--- a/timm/utils/model.py
+++ b/timm/utils/model.py
@@ -2,10 +2,15 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from .model_ema import ModelEma
+from typing import Sequence
+
 import torch
 import fnmatch
+
+from .model_ema import ModelEma
+from timm.models.layers.norm import FrozenBatchNorm2d
+
 
 def unwrap_model(model):
     if isinstance(model, ModelEma):
         return unwrap_model(model.ema)
@@ -89,4 +99,55 @@ def extract_spp_stats(model,
     hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
     _ = model(x)
     return hook.stats
-    
\ No newline at end of file
+
+
+def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True):
+    """
+    Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
+    done in place.
+    Args:
+        modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be
+            (un)frozen. If a string or strings are provided, these will be interpreted according to the named modules
+            of the provided ``root_module``.
+        root_module (nn.Module, optional): Root module relative to which named modules (accessible via
+            ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified
+            with a string or strings. Defaults to `None`.
+        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers.
+            Defaults to `True`.
+        mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`.
+
+    TODO before finalizing PR: Implement unfreezing of batch norm
+    """
+
+    if not isinstance(modules, Sequence) or isinstance(modules, str):
+        modules = [modules]
+
+    module_names = modules if isinstance(modules[0], str) else None
+    if module_names is not None:
+        assert root_module is not None, \
+            "When providing strings for the `modules` argument, a `root_module` must be provided"
+        modules = [root_module.get_submodule(m) for m in module_names]
+
+    for n, m in zip(module_names or [None] * len(modules), modules):
+        for p in m.parameters():
+            p.requires_grad = (not mode)
+        if include_bn_running_stats:
+            res = FrozenBatchNorm2d.convert_frozen_batchnorm(m)
+            # It's possible that `m` is itself a BatchNorm layer, in which case
+            # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place but will return the converted
+            # result. In that case `res` holds the new module and we try to re-assign it on the `root_module`.
+            if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
+                if module_names is not None and root_module is not None:
+                    root_module.add_module(n, res)
+                else:
+                    raise RuntimeError(
+                        "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling "
+                        "`freeze` with a list of module names while providing a `root_module` argument.")
+
+
+def unfreeze(modules, root_module=None, include_bn_running_stats=True):
+    """
+    Convenience function that calls `freeze` with `mode=False`. See the docstring of `freeze` for further
+    information.
+    """
+    freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False)
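
For review context, below is a minimal usage sketch of the new `freeze` / `unfreeze` helpers as they stand in this diff. The model and submodule names used here (`resnet50`, `conv1`, `bn1`, `layer1`) are illustrative assumptions, not part of the change.

# Hypothetical usage sketch (not part of the diff); model/submodule names are assumptions.
import timm
from timm.utils.model import freeze, unfreeze

model = timm.create_model('resnet50', pretrained=False)

# Freeze the stem and the first stage by name, relative to the model root. With
# include_bn_running_stats=True (the default), BatchNorm layers inside these submodules
# are also converted to FrozenBatchNorm2d, so their running statistics stop updating.
freeze(['conv1', 'bn1', 'layer1'], root_module=model)
assert all(not p.requires_grad for p in model.layer1.parameters())

# Make the first stage trainable again. Per the TODO in the diff, the frozen batch norm
# statistics themselves are not converted back at this stage of the PR.
unfreeze(['layer1'], root_module=model)
assert all(p.requires_grad for p in model.layer1.parameters())

Passing module names together with a `root_module`, rather than module objects, is what lets `freeze` swap a top-level batch norm layer (like `bn1` above) for its FrozenBatchNorm2d counterpart; that is the calling convention the RuntimeError hint points to.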