diff --git a/timm/utils/model.py b/timm/utils/model.py index c2786401..ffe66049 100644 --- a/timm/utils/model.py +++ b/timm/utils/model.py @@ -108,6 +108,8 @@ def freeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)): @@ -139,6 +141,8 @@ def unfreeze_batch_norm_2d(module): Returns: torch.nn.Module: Resulting module + + Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762 """ res = module if isinstance(module, FrozenBatchNorm2d):