pull/876/head
Alexander Soare 3 years ago
parent b5bf4dce98
commit 0cb8ea432c

@@ -3,6 +3,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class GroupNorm(nn.GroupNorm):
@@ -22,3 +23,42 @@ class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
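For context, a brief usage sketch of `LayerNorm2d` (the tensor shape and channel count below are illustrative assumptions, not part of this commit): it normalizes the channel dimension of an NCHW tensor by permuting to NHWC for `F.layer_norm`, then permuting back.

# Illustrative sketch (shapes assumed, not from this commit).
x = torch.randn(2, 64, 28, 28)   # (N, C, H, W)
norm = LayerNorm2d(64)           # normalized_shape matches the channel count
y = norm(x)                      # output shape is unchanged: (2, 64, 28, 28)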
class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Inherits from torchvision while adding the `convert_frozen_batchnorm` method from
    https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py
    """

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Converts all BatchNorm layers of the provided module into FrozenBatchNorm. If `module` is itself a type of
        BatchNorm, it is converted and the new module is returned. Otherwise, `module` is walked recursively and any
        BatchNorm type layers are converted in place.

        Args:
            module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant itself.

        Returns:
            torch.nn.Module: Resulting module
        """
        res = module
        if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
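A short usage sketch of `convert_frozen_batchnorm` follows; the toy `nn.Sequential` model is an assumption chosen for illustration, not part of this commit.

# Illustrative sketch (toy model assumed, not from this commit): convert the BatchNorm layers of a small model.
model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
model = FrozenBatchNorm2d.convert_frozen_batchnorm(model)
# The BatchNorm2d child has been replaced with a FrozenBatchNorm2d that carries over its statistics.
assert isinstance(model[1], FrozenBatchNorm2d)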

@@ -2,10 +2,20 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Sequence
import re
import warnings
import torch
import fnmatch

from .model_ema import ModelEma
from timm.models.layers.norm import FrozenBatchNorm2d
def unwrap_model(model):
    if isinstance(model, ModelEma):
        return unwrap_model(model.ema)
@@ -89,4 +99,55 @@ def extract_spp_stats(model,
    hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
    _ = model(x)
    return hook.stats
def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True):
    """
    Freeze or unfreeze the parameters of the specified modules and those of all their hierarchical descendants. This
    is done in place.

    Args:
        modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be
            (un)frozen. If a string or strings are provided, these will be interpreted according to the named modules
            of the provided ``root_module``.
        root_module (nn.Module, optional): Root module relative to which named modules (accessible via
            ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified
            with a string or strings. Defaults to `None`.
        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers.
            Defaults to `True`.
        mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`.

    TODO before finalizing PR: Implement unfreezing of batch norm
    """
    if not isinstance(modules, Sequence):
        modules = [modules]
    module_names = None
    if isinstance(modules[0], str):
        assert root_module is not None, \
            "When providing strings for the `modules` argument, a `root_module` must be provided"
        module_names = modules
        modules = [root_module.get_submodule(m) for m in module_names]
    for n, m in zip(module_names or [None] * len(modules), modules):
        for p in m.parameters():
            p.requires_grad = (not mode)
        if include_bn_running_stats:
            res = FrozenBatchNorm2d.convert_frozen_batchnorm(m)
            # It's possible that `m` is a type of BatchNorm itself, in which case
            # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted
            # result. In this case `res` holds the converted result and we may try to re-assign the named module.
            if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                if module_names is not None and root_module is not None:
                    root_module.add_module(n, res)
                else:
                    raise RuntimeError(
                        "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling "
                        "`freeze` with a list of module names while providing a `root_module` argument.")


def unfreeze(modules, root_module=None, include_bn_running_stats=True):
    """
    Idiomatic convenience function to call `freeze` with `mode == False`. See the docstring of `freeze` for further
    information.
    """
    freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False)
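A short usage sketch of `freeze` / `unfreeze` follows; the toy model and the submodule name 'stem' are assumptions chosen for illustration, not part of this commit.

# Illustrative sketch (toy model assumed, not from this commit): freeze a named submodule, then unfreeze it.
from collections import OrderedDict
import torch.nn as nn

model = nn.Sequential(OrderedDict([
    ('stem', nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))),
    ('head', nn.Linear(8, 10)),
]))
freeze(['stem'], root_module=model)    # stem parameters get requires_grad=False; its BatchNorm2d is converted
                                       # to FrozenBatchNorm2d so running stats stop updating
unfreeze(['stem'], root_module=model)  # stem parameters get requires_grad=True again; per the TODO above,
                                       # the frozen batch norm statistics are not (yet) converted back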
