Support BatchNormAct2d for sync-bn use. Fix #1254

pull/1317/head
Ross Wightman 2 years ago
parent 7cedc8d474
commit 879df47c0a

@@ -61,7 +61,7 @@ from .xcit import *
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
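
With the two re-exports above, the new helper is reachable from both the top-level models package and the layers package. A quick sanity check (not part of this commit, just illustrating the export paths):

from timm.models import convert_sync_batchnorm as convert_a
from timm.models.layers import convert_sync_batchnorm as convert_b

# both names should resolve to the same function after this change
assert convert_a is convert_b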

@@ -1,10 +1,15 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any
 
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
 
+try:
+    from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
+    FULL_SYNC_BN = True
+except ImportError:
+    FULL_SYNC_BN = False
 from .trace_utils import _assert
 from .create_act import get_act_layer
@@ -18,8 +23,27 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
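
For context, the widened constructor lets BatchNormAct2d accept the device/dtype factory kwargs that newer nn.BatchNorm2d versions support, while the except TypeError path keeps older PyTorch working. A minimal usage sketch (not part of the commit):

import torch
from timm.models.layers import BatchNormAct2d

# BN + activation fused into one module; device/dtype are forwarded to
# nn.BatchNorm2d when the installed PyTorch supports factory kwargs,
# otherwise they are silently ignored via the TypeError fallback above.
bn_act = BatchNormAct2d(32, act_layer=torch.nn.ReLU, dtype=torch.float32)
x = torch.randn(2, 32, 8, 8)
y = bn_act(x)                    # batch norm, then (optional) drop, then ReLU
print(y.shape, bool((y >= 0).all()))   # torch.Size([2, 32, 8, 8]) True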
@@ -81,6 +105,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
         return x
 
 
+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
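
For reference, a minimal sketch (not part of the commit) of how the new helper is meant to be used. Conversion itself can run anywhere, but the resulting SyncBatchNorm layers only synchronize statistics across processes once a distributed process group is initialized and the model is wrapped for distributed training.

import torch
import torch.nn as nn
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

# toy model mixing a plain BatchNorm2d and a timm BatchNormAct2d
model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.Conv2d(32, 32, 3, padding=1),
    BatchNormAct2d(32),
)

sync_model = convert_sync_batchnorm(model)

# plain BN becomes torch.nn.SyncBatchNorm, BatchNormAct2d becomes the
# SyncBatchNormAct wrapper with its act/drop submodules carried over
assert isinstance(sync_model[1], nn.SyncBatchNorm)
assert type(sync_model[3]).__name__ == 'SyncBatchNormAct'
assert hasattr(sync_model[3], 'act') and hasattr(sync_model[3], 'drop')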

@@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
@@ -26,14 +25,15 @@ from datetime import datetime
 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
 from timm import utils
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -440,10 +440,11 @@ def main():
     if args.distributed and args.sync_bn:
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
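
For context, the native (non-Apex) path now uses the timm conversion before the model is wrapped in DistributedDataParallel. A rough standalone sketch of that ordering (not part of the commit; assumes CUDA GPUs and that the script is launched with a distributed launcher such as torchrun so RANK/WORLD_SIZE/LOCAL_RANK are set; 'resnet50' is just a placeholder model name):

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.models import create_model, convert_sync_batchnorm

dist.init_process_group(backend='nccl')        # env vars set by the launcher
local_rank = int(os.environ['LOCAL_RANK'])
torch.cuda.set_device(local_rank)

model = create_model('resnet50')               # placeholder model name
model = convert_sync_batchnorm(model)          # handles BN and BatchNormAct2d layers
model = model.cuda()
model = NativeDDP(model, device_ids=[local_rank])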
