import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitBatchNorm2d(torch.nn.BatchNorm2d):
    """BatchNorm2d that normalizes each slice of the batch with its own BN layer.

    While training, the batch is cut along dim 0 into ``num_splits`` equal
    chunks. The first chunk goes through this layer's own (inherited)
    BatchNorm2d; chunk ``i + 1`` goes through ``aux_bn[i]``. The normalized
    chunks are concatenated back along dim 0 in the original order.

    In eval mode the auxiliary layers are ignored and the whole batch is
    normalized by the primary layer only.

    Args:
        num_features, eps, momentum, affine, track_running_stats: forwarded
            unchanged to ``torch.nn.BatchNorm2d`` for the primary layer and
            for every auxiliary layer.
        num_splits: total number of BN layers (primary + auxiliary); must be
            greater than 1.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True, num_splits=2):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
        self.num_splits = num_splits
        # One auxiliary BN per extra split; the primary split uses `self`.
        self.aux_bn = nn.ModuleList([
            nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)
            for _ in range(num_splits - 1)
        ])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Normalize ``input``; split across the BN layers only when training.

        Raises:
            AssertionError: in training mode, if ``input.shape[0]`` is not a
                multiple of ``num_splits``.
        """
        if not self.training:
            # Aux BN layers are only relevant while training.
            return super().forward(input)

        split_size = input.shape[0] // self.num_splits
        assert input.shape[0] == split_size * self.num_splits, "batch size must be evenly divisible by num_splits"
        chunks = input.split(split_size)
        # `super()` is called outside the generator below: zero-argument
        # super() is not reliable inside nested comprehension scopes.
        outputs = [super().forward(chunks[0])]
        outputs.extend(bn(chunk) for bn, chunk in zip(self.aux_bn, chunks[1:]))
        return torch.cat(outputs, dim=0)


def _clone_or_none(tensor):
    """Return ``tensor.clone()``, or ``None`` when ``tensor`` is ``None``."""
    return None if tensor is None else tensor.clone()


def convert_splitbn_model(module, num_splits=2):
    """
    Recursively traverse module and its children to replace all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with `SplitBatchNorm2d`.

    Instance-norm layers are returned untouched (without descending into
    them). Every other module is kept as-is and has its children converted
    and re-registered under their original names, i.e. containers are
    modified in place.

    For each replaced layer the primary BN takes over the source layer's
    running statistics and a copy of its affine parameters; every auxiliary
    BN receives its own copies of both. The replacement also inherits the
    source layer's train/eval mode.

    NOTE(review): every ``_BatchNorm`` instance is replaced by a *2d* split
    layer regardless of its own dimensionality — confirm that models passed
    here only contain 2d batch norms.

    Args:
        module (torch.nn.Module): input module
        num_splits: number of separate batchnorm layers to split input across

    Returns:
        The converted module: a new `SplitBatchNorm2d` if ``module`` itself
        was a batch-norm layer, otherwise ``module``.

    Example::
        >>> # model is an instance of torch.nn.Module
        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
    """
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module

    mod = module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SplitBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats, num_splits=num_splits)
        # Keep the source layer's mode; a freshly built module defaults to
        # training, which would make a converted eval model split batches.
        mod.train(module.training)

        # The primary layer adopts the source buffers directly. These are
        # None when the source was built with track_running_stats=False.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()

        # Each auxiliary layer gets independent copies so that its
        # statistics and parameters can diverge during training.
        for aux in mod.aux_bn:
            aux.running_mean = _clone_or_none(module.running_mean)
            aux.running_var = _clone_or_none(module.running_var)
            aux.num_batches_tracked = _clone_or_none(module.num_batches_tracked)
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()

    for name, child in module.named_children():
        mod.add_module(name, convert_splitbn_model(child, num_splits=num_splits))
    return mod