|
|
|
@ -108,6 +108,8 @@ def freeze_batch_norm_2d(module):
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
torch.nn.Module: Resulting module
|
|
|
|
|
|
|
|
|
|
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
|
|
|
|
"""
|
|
|
|
|
res = module
|
|
|
|
|
if isinstance(module, (torch.nn.modules.batchnorm.BatchNorm2d, torch.nn.modules.batchnorm.SyncBatchNorm)):
|
|
|
|
@ -139,6 +141,8 @@ def unfreeze_batch_norm_2d(module):
|
|
|
|
|
|
|
|
|
|
Returns:
|
|
|
|
|
torch.nn.Module: Resulting module
|
|
|
|
|
|
|
|
|
|
Inspired by https://github.com/pytorch/pytorch/blob/a5895f85be0f10212791145bfedc0261d364f103/torch/nn/modules/batchnorm.py#L762
|
|
|
|
|
"""
|
|
|
|
|
res = module
|
|
|
|
|
if isinstance(module, FrozenBatchNorm2d):
|
|
|
|
|