From 0cb8ea432ce1648ba28171080216d84544b62d1d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Tue, 21 Sep 2021 12:46:42 +0100 Subject: [PATCH 1/5] wip --- timm/models/layers/norm.py | 40 +++++++++++++++++++++++ timm/utils/model.py | 65 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 103 insertions(+), 2 deletions(-) diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index aace107b..fc500807 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn import torch.nn.functional as F +import torchvision class GroupNorm(nn.GroupNorm): @@ -22,3 +23,42 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) + + +class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + Inherits from torchvision while adding the `convert_frozen_batchnorm` from + https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py + """ + + @classmethod + def convert_frozen_batchnorm(cls, module): + """ + Converts all BatchNorm layers of provided module into FrozenBatchNorm. If `module` is a type of BatchNorm, it + converts it into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm type layers are + converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant in itself. + + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): + res = cls(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = cls.convert_frozen_batchnorm(child) + if new_child is not child: + res.add_module(name, new_child) + return res + diff --git a/timm/utils/model.py b/timm/utils/model.py index bd46e2f4..d0fb69ed 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -2,10 +2,20 @@ Hacked together by / Copyright 2020 Ross Wightman """ -from .model_ema import ModelEma +from logging import root +from typing import Sequence +import re +import warnings + import torch import fnmatch +from torch.nn.modules import module + +from .model_ema import ModelEma +from timm.models.layers.norm import FrozenBatchNorm2d + + def unwrap_model(model): if isinstance(model, ModelEma): return unwrap_model(model.ema) @@ -89,4 +99,55 @@ def extract_spp_stats(model, hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats - \ No newline at end of file + + +def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True): + """ + Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is + done in place. + Args: + modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be + (un)frozen. If a string or strings are provided these will be interpreted according to the named modules + of the provided ``root_module``. + root_module (nn.Module, optional): Root module relative to which named modules (accessible via + ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified + with a string or strings. Defaults to `None`. + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers. + Defaults to `True`. + mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`. + + TODO before finalizing PR: Implement unfreezing of batch norm + """ + + if not isinstance(modules, Sequence): + modules = [modules] + + if isinstance(modules[0], str): + assert root_module is not None, \ + "When providing strings for the `modules` argument, a `root_module` must be provided" + module_names = modules + modules = [root_module.get_submodule(m) for m in module_names] + + for n, m in zip(module_names, modules): + for p in m.parameters(): + p.requires_grad = (not mode) + if include_bn_running_stats: + res = FrozenBatchNorm2d.convert_frozen_batchnorm(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case + # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted + # result. In this case `res` holds the converted result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + if module_names is not None and root_module is not None: + root_module.add_module(n, res) + else: + raise RuntimeError( + "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling " + "`freeze` with a list of module names while providing a `root_module` argument.") + + +def unfreeze(modules, root_module=None, include_bn_running_stats=True): + """ + Idiomatic convenience function to call `freeze` with `mode == False`. See docstring of `freeze` for further + information. + """ + freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False) From 65c3d78b96c8adce220675841de199516b3041ff Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 2 Oct 2021 15:54:14 +0100 Subject: [PATCH 2/5] Freeze unfreeze functionality finalized. Tests added --- tests/test_utils.py | 60 ++++++++++++ timm/models/layers/norm.py | 40 -------- timm/utils/model.py | 195 +++++++++++++++++++++++++++++-------- 3 files changed, 214 insertions(+), 81 deletions(-) create mode 100644 tests/test_utils.py diff --git a/tests/test_utils.py b/tests/test_utils.py new file mode 100644 index 00000000..3e11eacc --- /dev/null +++ b/tests/test_utils.py @@ -0,0 +1,60 @@ +from torch.nn.modules.batchnorm import BatchNorm2d +from torchvision.ops.misc import FrozenBatchNorm2d + +import timm +from timm.utils.model import freeze, unfreeze + + +def test_freeze_unfreeze(): + model = timm.create_model('resnet18') + + # Freeze all + freeze(model) + # Check top level module + assert model.fc.weight.requires_grad == False + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == False + # Check BN + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + + # Unfreeze all + unfreeze(model) + # Check top level module + assert model.fc.weight.requires_grad == True + # Check submodule + assert model.layer1[0].conv1.weight.requires_grad == True + # Check BN + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + + # Freeze some + freeze(model, ['layer1', 'layer2.0']) + # Check frozen + assert model.layer1[0].conv1.weight.requires_grad == False + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == False + # Check not frozen + assert model.layer3[0].conv1.weight.requires_grad == True + assert isinstance(model.layer3[0].bn1, BatchNorm2d) + assert model.layer2[1].conv1.weight.requires_grad == True + + # Unfreeze some + unfreeze(model, ['layer1', 'layer2.0']) + # Check not frozen + assert model.layer1[0].conv1.weight.requires_grad == True + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + assert model.layer2[0].conv1.weight.requires_grad == True + + # Freeze BN + # From root + freeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) + + # Unfreeze BN + unfreeze(model, ['layer1.0.bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) + # From direct parent + unfreeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file diff --git a/timm/models/layers/norm.py b/timm/models/layers/norm.py index fc500807..aace107b 100644 --- a/timm/models/layers/norm.py +++ b/timm/models/layers/norm.py @@ -3,7 +3,6 @@ import torch import torch.nn as nn import torch.nn.functional as F -import torchvision class GroupNorm(nn.GroupNorm): @@ -23,42 +22,3 @@ class LayerNorm2d(nn.LayerNorm): def forward(self, x: torch.Tensor) -> torch.Tensor: return F.layer_norm( x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2) - - -class FrozenBatchNorm2d(torchvision.ops.misc.FrozenBatchNorm2d): - """ - BatchNorm2d where the batch statistics and the affine parameters are fixed. - - Inherits from torchvision while adding the `convert_frozen_batchnorm` from - https://github.com/facebookresearch/detectron2/blob/cbbc1ce26473cb2a5cc8f58e8ada9ae14cb41052/detectron2/layers/batch_norm.py - """ - - @classmethod - def convert_frozen_batchnorm(cls, module): - """ - Converts all BatchNorm layers of provided module into FrozenBatchNorm. If `module` is a type of BatchNorm, it - converts it into FrozenBatchNorm. Otherwise, the module is walked recursively and BatchNorm type layers are - converted in place. - - Args: - module (torch.nn.Module): Any PyTorch module. It doesn't have to be a BatchNorm variant in itself. - - Returns: - torch.nn.Module: Resulting module - """ - res = module - if isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): - res = cls(module.num_features) - if module.affine: - res.weight.data = module.weight.data.clone().detach() - res.bias.data = module.bias.data.clone().detach() - res.running_mean.data = module.running_mean.data - res.running_var.data = module.running_var.data - res.eps = module.eps - else: - for name, child in module.named_children(): - new_child = cls.convert_frozen_batchnorm(child) - if new_child is not child: - res.add_module(name, new_child) - return res - diff --git a/timm/utils/model.py b/timm/utils/model.py index d0fb69ed..c2786401 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -4,16 +4,12 @@ Hacked together by / Copyright 2020 Ross Wightman """ from logging import root from typing import Sequence -import re -import warnings import torch import fnmatch - -from torch.nn.modules import module +from torchvision.ops.misc import FrozenBatchNorm2d from .model_ema import ModelEma -from timm.models.layers.norm import FrozenBatchNorm2d def unwrap_model(model): @@ -99,55 +95,172 @@ def extract_spp_stats(model, hook = ActivationStatsHook(model, hook_fn_locs=hook_fn_locs, hook_fns=hook_fns) _ = model(x) return hook.stats - -def freeze(modules, root_module=None, include_bn_running_stats=True, mode=True): + +def freeze_batch_norm_2d(module): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + res = FrozenBatchNorm2d(module.num_features) + res.num_features = module.num_features + res.affine = module.affine + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = freeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def unfreeze_batch_norm_2d(module): + """ + Converts all `FrozenBatchNorm2d` layers of provided module into `BatchNorm2d`. If `module` is itself and instance + of `FrozenBatchNorm2d`, it is converted into `BatchNorm2d` and returned. Otherwise, the module is walked + recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. + + Returns: + torch.nn.Module: Resulting module + """ + res = module + if isinstance(module, FrozenBatchNorm2d): + res = torch.nn.BatchNorm2d(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = unfreeze_batch_norm_2d(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, mode='freeze'): """ Freeze or unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. Args: - modules (nn.Module or list[nn.Module] or str or list[str]): List of modules for which the parameters will be - (un)frozen. If a string or strings are provided these will be interpreted according to the named modules - of the provided ``root_module``. - root_module (nn.Module, optional): Root module relative to which named modules (accessible via - ``root_module.named_modules()``) are referenced. Must be provided if the `modules` argument is specified - with a string or strings. Defaults to `None`. - include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm layers. + root_module (nn.Module, optional): Root module relative to which the `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be (un)frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be (un)frozen. Defaults to [] + include_bn_running_stats (bool): Whether to also (un)freeze the running statistics of batch norm 2d layers. Defaults to `True`. - mode (bool): Whether to freeze (`True`) or unfreeze (`False`). Defaults to `True`. - - TODO before finalizing PR: Implement unfreezing of batch norm + mode (bool): Whether to freeze ("freeze") or unfreeze ("unfreeze"). Defaults to `"freeze"`. """ - - if not isinstance(modules, Sequence): - modules = [modules] + assert mode in ["freeze", "unfreeze"], '`mode` must be one of "freeze" or "unfreeze"' + + if isinstance(root_module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + # Raise assertion here because we can't convert it in place + raise AssertionError( + "You have provided a batch norm layer as the `root module`. Please use " + "`timm.utils.model.freeze_batch_norm_2d` or `timm.utils.model.unfreeze_batch_norm_2d` instead.") - if isinstance(modules[0], str): - assert root_module is not None, \ - "When providing strings for the `modules` argument, a `root_module` must be provided" - module_names = modules - modules = [root_module.get_submodule(m) for m in module_names] + if isinstance(submodules, str): + submodules = [submodules] - for n, m in zip(module_names, modules): + named_modules = submodules + submodules = [root_module.get_submodule(m) for m in submodules] + + if not(len(submodules)): + named_modules, submodules = list(zip(*root_module.named_children())) + + for n, m in zip(named_modules, submodules): + # (Un)freeze parameters for p in m.parameters(): - p.requires_grad = (not mode) + p.requires_grad = (False if mode == 'freeze' else True) if include_bn_running_stats: - res = FrozenBatchNorm2d.convert_frozen_batchnorm(m) - # It's possible that `m` is a type of BatchNorm in itself, in which case - # `FrozenBatchNorm2d.convert_frozen_batchnorm` won't convert it in place, but will return the converted - # result. In this case `res` holds the converted result and we may try to re-assign the named module - if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): - if module_names is not None and root_module is not None: - root_module.add_module(n, res) + # Helper to add submodule specified as a named_module + def _add_submodule(module, name, submodule): + split = name.rsplit('.', 1) + if len(split) > 1: + module.get_submodule(split[0]).add_module(split[1], submodule) else: - raise RuntimeError( - "Could not freeze batch norm statistics due to a technical limitation. Hint: Try calling " - "`freeze` with a list of module names while providing a `root_module` argument.") + module.add_module(name, submodule) + # Freeze batch norm + if mode == 'freeze': + res = freeze_batch_norm_2d(m) + # It's possible that `m` is a type of BatchNorm in itself, in which case `unfreeze_batch_norm_2d` won't + # convert it in place, but will return the converted result. In this case `res` holds the converted + # result and we may try to re-assign the named module + if isinstance(m, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): + _add_submodule(root_module, n, res) + # Unfreeze batch norm + else: + res = unfreeze_batch_norm_2d(m) + # Ditto. See note above in mode == 'freeze' branch + if isinstance(m, FrozenBatchNorm2d): + _add_submodule(root_module, n, res) + + +def freeze(root_module, submodules=[], include_bn_running_stats=True): + """ + Freeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of modules for which the parameters will be frozen. They are to be provided as + named modules relative to the root module (accessible via `root_module.named_modules()`). An empty list + means that the whole root module will be frozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also freeze the running statistics of `BatchNorm2d` and + `SyncBatchNorm` layers. These will be converted to `FrozenBatchNorm2d` in place. Hint: During fine tuning, + it's good practice to freeze batch norm stats. And note that these are different to the affine parameters + which are just normal PyTorch parameters. Defaults to `True`. + + Hint: If you want to freeze batch norm ONLY, use `timm.utils.model.freeze_batch_norm_2d`. + + Examples:: + + >>> model = timm.create_model('resnet18') + >>> # Freeze up to and including layer2 + >>> submodules = [n for n, _ in model.named_children()] + >>> print(submodules) + ['conv1', 'bn1', 'act1', 'maxpool', 'layer1', 'layer2', 'layer3', 'layer4', 'global_pool', 'fc'] + >>> freeze(model, submodules[:submodules.index('layer2') + 1]) + >>> # Check for yourself that it works as expected + >>> print(model.layer2[0].conv1.weight.requires_grad) + False + >>> print(model.layer3[0].conv1.weight.requires_grad) + True + >>> # Unfreeze + >>> unfreeze(model) + """ + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="freeze") -def unfreeze(modules, root_module=None, include_bn_running_stats=True): +def unfreeze(root_module, submodules=[], include_bn_running_stats=True): """ - Idiomatic convenience function to call `freeze` with `mode == False`. See docstring of `freeze` for further - information. + Unfreeze parameters of the specified modules and those of all their hierarchical descendants. This is done in place. + Args: + root_module (nn.Module): Root module relative to which `submodules` are referenced. + submodules (list[str]): List of submodules for which the parameters will be (un)frozen. They are to be provided + as named modules relative to the root module (accessible via `root_module.named_modules()`). An empty + list means that the whole root module will be unfrozen. Defaults to `[]`. + include_bn_running_stats (bool): Whether to also unfreeze the running statistics of `FrozenBatchNorm2d` layers. + These will be converted to `BatchNorm2d` in place. Defaults to `True`. + + See example in docstring for `freeze`. """ - freeze(modules, root_module=root_module, include_bn_running_stats=include_bn_running_stats, mode=False) + _freeze_unfreeze(root_module, submodules, include_bn_running_stats=include_bn_running_stats, mode="unfreeze") + \ No newline at end of file From 6d2acec1bb1fa40fa8f65771e2f18c805920d4e3 Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Sat, 2 Oct 2021 16:10:11 +0100 Subject: [PATCH 3/5] Fix ordering of tests --- tests/test_utils.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 3e11eacc..b0f890d2 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -44,17 +44,14 @@ def test_freeze_unfreeze(): assert isinstance(model.layer1[0].bn1, BatchNorm2d) assert model.layer2[0].conv1.weight.requires_grad == True - # Freeze BN + # Freeze/unfreeze BN # From root freeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - # From direct parent - freeze(model.layer1[0], ['bn1']) - assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) - - # Unfreeze BN unfreeze(model, ['layer1.0.bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) # From direct parent + freeze(model.layer1[0], ['bn1']) + assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d) unfreeze(model.layer1[0], ['bn1']) assert isinstance(model.layer1[0].bn1, BatchNorm2d) \ No newline at end of file From 431e60c83fb5db0991a5af55f9e9b635ecea9d8d Mon Sep 17 00:00:00 2001 From: Alexander Soare Date: Wed, 6 Oct 2021 14:28:49 +0100 Subject: [PATCH 4/5] Add acknowledgements for freeze_batch_norm inspiration --- timm/utils/model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/timm/utils/model.py b/timm/utils/model.py index c2786401..ffe66049 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -108,6 +108,8 @@ def freeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): @@ -139,6 +141,8 @@ def unfreeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, FrozenBatchNorm2d): From e5da481073ac4beb634ee9b33e264baa3bee8688 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 6 Oct 2021 17:00:27 -0700 Subject: [PATCH 5/5] Small post-merge tweak for freeze/unfreeze, add to __init__ for utils --- timm/utils/__init__.py | 2 +- timm/utils/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/timm/utils/__init__.py b/timm/utils/__init__.py index d02e62d2..11de9c9c 100644 --- a/timm/utils/__init__.py +++ b/timm/utils/__init__.py @@ -7,7 +7,7 @@ from .jit import set_jit_legacy from .log import setup_default_logging, FormatterNoInfo from .metrics import AverageMeter, accuracy from .misc import natural_key, add_bool_arg -from .model import unwrap_model, get_state_dict +from .model import unwrap_model, get_state_dict, freeze, unfreeze from .model_ema import ModelEma, ModelEmaV2 from .random import random_seed from .summary import update_summary, get_outdir diff --git a/timm/utils/model.py b/timm/utils/model.py index ffe66049..879ac3f8 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -194,7 +194,7 @@ def _freeze_unfreeze(root_module, submodules=[], include_bn_running_stats=True, for n, m in zip(named_modules, submodules): # (Un)freeze parameters for p in m.parameters(): - p.requires_grad = (False if mode == 'freeze' else True) + p.requires_grad = False if mode == 'freeze' else True if include_bn_running_stats: # Helper to add submodule specified as a named_module def _add_submodule(module, name, submodule):