pytorch-image-models/timm/models/layers/split_batchnorm.py

76 lines
3.4 KiB

""" Split BatchNorm

A PyTorch BatchNorm layer that splits input batch into N equal parts and passes each through
a separate BN layer. The first split is passed through the parent BN layer with weight/bias
keys the same as the original BN. All other splits pass through BN sub-layers under the '.aux_bn'
namespace.

This allows easily removing the auxiliary BN layers after training to efficiently
achieve the 'Auxiliary BatchNorm' as described in the AdvProp Paper, section 4.2,
'Disentangled Learning via An Auxiliary BN'

Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes each of ``num_splits`` batch chunks separately.

    While training, the input batch is cut into ``num_splits`` equal chunks
    along dim 0. The first chunk is normalized by this layer itself (so its
    parameter/buffer keys are those of a plain ``BatchNorm2d``); every further
    chunk is normalized by its own ``BatchNorm2d`` held in ``self.aux_bn``.
    When not training, the whole batch goes through this layer only and the
    auxiliary layers are not used.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded
            unchanged to ``BatchNorm2d`` for both the main and auxiliary layers.
        num_splits: total number of chunks (main + auxiliary); must be > 1.

    Raises:
        AssertionError: if ``num_splits`` is not greater than 1, or (in
            training mode) if the batch size is not divisible by ``num_splits``.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        # Raised explicitly (rather than via `assert`) so the check is still
        # enforced under `python -O`; the exception type is kept for callers.
        if not num_splits > 1:
            raise AssertionError('Should have at least one aux BN layer (num_splits at least 2)')
        self.num_splits = num_splits
        # One auxiliary BN per chunk beyond the first.
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
            for _ in range(num_splits - 1)
        ])

    def forward(self, input: torch.Tensor):
        """Normalize ``input``; per-chunk when training, as a whole otherwise."""
        if not self.training:
            # aux BN only relevant while training
            return super().forward(input)

        batch_size = input.shape[0]
        split_size = batch_size // self.num_splits
        # Without this check an uneven batch would produce an extra trailing
        # chunk that is never normalized and silently dropped from the output.
        if batch_size != split_size * self.num_splits:
            raise AssertionError("batch size must be evenly divisible by num_splits")

        chunks = input.split(split_size)
        outputs = [super().forward(chunks[0])]
        for idx, aux in enumerate(self.aux_bn, start=1):
            outputs.append(aux(chunks[idx]))
        return torch.cat(outputs, dim=0)
def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.

    The replacement's main layer takes over the original running statistics
    and a copy of the affine parameters; every auxiliary BN layer is
    initialised with its own copies of the same values.

    NOTE: every ``_BatchNorm`` subclass is replaced by the 2d split variant,
    regardless of the original layer's dimensionality.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across

    Returns:
        The converted module: a new `SplitBatchNorm2d` if ``module`` is a
        batch norm layer, otherwise ``module`` itself with its children
        converted in place.

    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    # Instance norm layers are returned untouched (checked before the
    # batch norm test so they can never be converted).
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    def _clone_stat(stat):
        # Running statistics are None when track_running_stats=False.
        return None if stat is None else stat.clone()

    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # The main layer reuses the original buffers directly.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        # Each auxiliary layer gets independent copies so its statistics
        # and parameters can diverge from the main layer during training.
        for aux in mod.aux_bn:
            aux.running_mean = _clone_stat(module.running_mean)
            aux.running_var = _clone_stat(module.running_var)
            aux.num_batches_tracked = _clone_stat(module.num_batches_tracked)
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()

    # Convert children recursively and (re-)register them on the result.
    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod