A few minor things in SplitBN

pull/74/head
Ross Wightman 5 years ago
parent 7547119891
commit 833066b540

@@ -6,9 +6,9 @@ import torch.nn.functional as F
 class SplitBatchNorm2d(torch.nn.BatchNorm2d):
     def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
-                 track_running_stats=True, num_splits=1):
+                 track_running_stats=True, num_splits=2):
         super().__init__(num_features, eps, momentum, affine, track_running_stats)
-        assert num_splits >= 2, 'Should have at least one aux BN layer (num_splits at least 2)'
+        assert num_splits > 1, 'Should have at least one aux BN layer (num_splits at least 2)'
         self.num_splits = num_splits
         self.aux_bn = nn.ModuleList([
             nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats) for _ in range(num_splits - 1)])
@@ -35,8 +35,7 @@ def convert_splitbn_model(module, num_splits=2):
         num_splits: number of separate batchnorm layers to split input across
     Example::
         >>> # model is an instance of torch.nn.Module
-        >>> import apex
-        >>> sync_bn_model = timm.models.convert_splitbn_model(model, num_splits=2)
+        >>> model = timm.models.convert_splitbn_model(model, num_splits=2)
     """
     mod = module
     if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):

Loading…
Cancel
Save