Freeze unfreeze functionality finalized. Tests added

pull/876/head
Alexander Soare 3 years ago
parent 0cb8ea432c
commit 65c3d78b96

@@ -0,0 +1,60 @@
from torch.nn.modules.batchnorm import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d
import timm
from timm.utils.model import freeze, unfreeze
def test_freeze_unfreeze():
    model = timm.create_model('resnet18')

    # Freeze all
    freeze(model)
    # Check top level module
    assert model.fc.weight.requires_grad == False
    # Check submodule
    assert model.layer1[0].conv1.weight.requires_grad == False
    # Check BN
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    # Unfreeze all
    unfreeze(model)
    # Check top level module
    assert model.fc.weight.requires_grad == True
    # Check submodule
    assert model.layer1[0].conv1.weight.requires_grad == True
    # Check BN
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

    # Freeze some
    freeze(model, ['layer1', 'layer2.0'])
    # Check frozen
    assert model.layer1[0].conv1.weight.requires_grad == False
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad == False
    # Check not frozen
    assert model.layer3[0].conv1.weight.requires_grad == True
    assert isinstance(model.layer3[0].bn1, BatchNorm2d)
    assert model.layer2[1].conv1.weight.requires_grad == True

    # Unfreeze some
    unfreeze(model, ['layer1', 'layer2.0'])
    # Check not frozen
    assert model.layer1[0].conv1.weight.requires_grad == True
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    assert model.layer2[0].conv1.weight.requires_grad == True

    # Freeze BN
    # From root
    freeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    # From direct parent
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)

    # Unfreeze BN
    unfreeze(model, ['layer1.0.bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
    # From direct parent
    unfreeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, BatchNorm2d)

@@ -3,7 +3,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
class GroupNorm(nn.GroupNorm):
@@ -23,42 +22,3 @@ class LayerNorm2d(nn.LayerNorm):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.layer_norm(
            x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)

class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    Inherits from torchvision while adding the `convert_frozen_batchnorm` method from
    https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py
    """

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Converts all BatchNorm layers of the provided module into FrozenBatchNorm. If `module` is itself a type of
        BatchNorm, it is converted into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm
        type layers are converted in place.

        Args:
            module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant in itself.

        Returns:
            torch.nn.Module: Resulting module
        """
        res = module
        if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
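
# Editor's note (illustrative sketch, not part of this commit's diff): the `convert_frozen_batchnorm` classmethod
# removed above is effectively superseded by the standalone `freeze_batch_norm_2d` helper added to
# `timm.utils.model` further down in this commit. A minimal sketch of the replacement usage, assuming a stock
# timm model:
import timm
from timm.utils.model import freeze_batch_norm_2d

model = timm.create_model('resnet18')
model = freeze_batch_norm_2d(model)  # every BatchNorm2d / SyncBatchNorm submodule becomes FrozenBatchNorm2d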

@@ -4,16 +4,12 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Sequence
import re
import warnings
import torch
import fnmatch
from torchvision.ops.misc import FrozenBatchNorm2d
from .model_ema import ModelEma
def unwrap_model(model):
@@ -99,55 +95,172 @@ def extract_spp_stats(model,
    hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns)
    _ = model(x)
    return hook.stats

def freeze_batch_norm_2d(module):
    """
    Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is
    itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and
    returned. Otherwise, the module is walked recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module
    """
    res = module
    if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        res = FrozenBatchNorm2d(module.num_features)
        res.num_features = module.num_features
        res.affine = module.affine
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = freeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res

def unfreeze_batch_norm_2d(module):
    """
    Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself an instance
    of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked
    recursively and submodules are converted in place.

    Args:
        module (torch.nn.Module): Any PyTorch module.

    Returns:
        torch.nn.Module: Resulting module
    """
    res = module
    if isinstance(module, FrozenBatchNorm2d):
        res = torch.nn.BatchNorm2d(module.num_features)
        if module.affine:
            res.weight.data = module.weight.data.clone().detach()
            res.bias.data = module.bias.data.clone().detach()
        res.running_mean.data = module.running_mean.data
        res.running_var.data = module.running_var.data
        res.eps = module.eps
    else:
        for name, child in module.named_children():
            new_child = unfreeze_batch_norm_2d(child)
            if new_child is not child:
                res.add_module(name, new_child)
    return res
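
# Editor's note (illustrative sketch, not part of this diff): round-tripping batch norm through the two helpers
# above. Assumes a stock timm ResNet; `layer1[0]` is just an example block.
import timm
from torch.nn import BatchNorm2d
from torchvision.ops.misc import FrozenBatchNorm2d

model = timm.create_model('resnet18')
model.layer1[0] = freeze_batch_norm_2d(model.layer1[0])    # bn1/bn2 inside the block -> FrozenBatchNorm2d
assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
model.layer1[0] = unfreeze_batch_norm_2d(model.layer1[0])  # ...and back -> BatchNorm2d
assert isinstance(model.layer1[0].bn1, BatchNorm2d)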
def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'):
    """
    Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is
    done in place.

    Args:
        root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced.
        submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided
            as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
            list means that the whole root module will be (un)frozen. Defaults to `[]`.
        include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers.
            Defaults to `True`.
        mode (str): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`.
    """
    assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"'

    if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
        # Raise assertion here because we can't convert it in place
        raise AssertionError(
            "You have provided a batch norm layer as the `root module`. Please use "
            "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.")

    if isinstance(submodules, str):
        submodules = [submodules]

    named_modules = submodules
    submodules = [root_module.get_submodule(m) for m in submodules]

    if not len(submodules):
        named_modules, submodules = list(zip(*root_module.named_children()))

    for n, m in zip(named_modules, submodules):
        # (Un)freeze parameters
        for p in m.parameters():
            p.requires_grad = (False if mode == 'freeze' else True)
        if include_bn_running_stats:
            # Helper to add submodule specified as a named_module
            def _add_submodule(module, name, submodule):
                split = name.rsplit('.', 1)
                if len(split) > 1:
                    module.get_submodule(split[0]).add_module(split[1], submodule)
                else:
                    module.add_module(name, submodule)

            # Freeze batch norm
            if mode == 'freeze':
                res = freeze_batch_norm_2d(m)
                # It's possible that `m` is a type of BatchNorm in itself, in which case `freeze_batch_norm_2d` won't
                # convert it in place, but will return the converted result. In this case `res` holds the converted
                # result and we may try to re-assign the named module
                if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
                    _add_submodule(root_module, n, res)
            # Unfreeze batch norm
            else:
                res = unfreeze_batch_norm_2d(m)
                # Ditto. See note above in mode == 'freeze' branch
                if isinstance(m, FrozenBatchNorm2d):
                    _add_submodule(root_module, n, res)
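
# Editor's note (illustrative sketch, not part of this diff): how `_freeze_unfreeze` resolves dotted submodule
# names. `get_submodule` is the standard `torch.nn.Module` lookup used above; names follow
# `root_module.named_modules()`.
import timm

model = timm.create_model('resnet18')
block = model.get_submodule('layer2.0')    # same object as model.layer2[0]
bn = model.get_submodule('layer1.0.bn1')   # nested lookup, as used for freeze(model, ['layer1.0.bn1'])
assert block is model.layer2[0] and bn is model.layer1[0].bn1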
def freeze(root_module, submodules=[], include_bn_running_stats=True):
    """
    Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place.

    Args:
        root_module (nn.Module): Root module relative to which `submodules` are referenced.
        submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as
            named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list
            means that the whole root module will be frozen. Defaults to `[]`.
        include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and
            `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine-tuning,
            it's good practice to freeze batch norm stats. And note that these are different from the affine
            parameters which are just normal PyTorch parameters. Defaults to `True`.

    Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`.

    Examples::

        >>> model = timm.create_model('resnet18')
        >>> # Freeze up to and including layer2
        >>> submodules = [n for n, _ in model.named_children()]
        >>> print(submodules)
        ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc']
        >>> freeze(model, submodules[:submodules.index('layer2') + 1])
        >>> # Check for yourself that it works as expected
        >>> print(model.layer2[0].conv1.weight.requires_grad)
        False
        >>> print(model.layer3[0].conv1.weight.requires_grad)
        True
        >>> # Unfreeze
        >>> unfreeze(model)
    """
    _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze")
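
# Editor's note (illustrative sketch, not part of this diff): a common fine-tuning pattern with `freeze`, keeping
# only the classifier head trainable. Assumes a stock timm ResNet whose head is named 'fc'.
import timm
from timm.utils.model import freeze

model = timm.create_model('resnet18', num_classes=10)
freeze(model, [n for n, _ in model.named_children() if n != 'fc'])
assert not model.layer4[0].conv1.weight.requires_grad
assert model.fc.weight.requires_grad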
def unfreeze(root_module, submodules=[], include_bn_running_stats=True):
    """
    Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in
    place.

    Args:
        root_module (nn.Module): Root module relative to which `submodules` are referenced.
        submodules (list[str]): List of submodules for which the parameters will be unfrozen. They are to be provided
            as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty
            list means that the whole root module will be unfrozen. Defaults to `[]`.
        include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d`
            layers. These will be converted to `BatchNorm2d` in place. Defaults to `True`.

    See example in docstring for `freeze`.
    """
    _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze")
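
# Editor's note (illustrative sketch, not part of this diff): gradual unfreezing over the course of fine-tuning,
# releasing the deepest stages first. Assumes timm ResNet stage names ('layer4', 'layer3', ...).
import timm
from timm.utils.model import freeze, unfreeze

model = timm.create_model('resnet18')
freeze(model)                           # start fully frozen; BN layers become FrozenBatchNorm2d
for stage in ['layer4', 'layer3']:      # later in training, progressively release stages
    unfreeze(model, [stage])            # restores requires_grad and converts FrozenBatchNorm2d back to BatchNorm2d
assert model.layer4[0].conv1.weight.requires_grad
assert not model.layer2[0].conv1.weight.requires_grad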