diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index 8cb6c70a..4f81683a 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -61,7 +61,7 @@ from .xcit import *
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 7e9e7b19..b1a64db3 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py
index 34c4fd64..261bdb0a 100644
--- a/timm/models/layers/norm_act.py
+++ b/timm/models/layers/norm_act.py
@@ -1,10 +1,15 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any
 
 import torch
 from torch import nn as nn
 from torch.nn import functional as F
+try:
+    from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm
+    FULL_SYNC_BN = True
+except ImportError:
+    FULL_SYNC_BN = False
 
 from .trace_utils import _assert
 from .create_act import get_act_layer
@@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
         act_layer = get_act_layer(act_layer)  # string -> nn.Module
         if act_layer is not None and apply_act:
@@ -81,6 +105,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
         return x
 
 
+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
diff --git a/train.py b/train.py
index 444be066..2a68e05e 100755
--- a/train.py
+++ b/train.py
@@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
@@ -26,14 +25,15 @@ from datetime import datetime
 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP
 
-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
 from timm import utils
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -440,10 +440,11 @@ def main():
     if args.distributed and args.sync_bn:
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
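
For reference (not part of the patch): a minimal usage sketch of the new helper outside of train.py. Assumptions: the process is launched with torchrun so env:// initialization works, and 'resnet50' is only a placeholder model name. Plain BatchNorm layers are converted to torch.nn.SyncBatchNorm, while timm BatchNormAct2d layers become SyncBatchNormAct with their act/drop submodules carried over.

    import os

    import torch
    import torch.distributed as dist
    from timm.models import create_model, convert_sync_batchnorm

    # assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set by torchrun
    dist.init_process_group(backend='nccl')
    local_rank = int(os.environ.get('LOCAL_RANK', '0'))
    torch.cuda.set_device(local_rank)

    model = create_model('resnet50', pretrained=False)
    model = convert_sync_batchnorm(model)  # swap BatchNorm2d / BatchNormAct2d for sync variants
    model = model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])

As in the train.py hunk above, the conversion is applied before wrapping the model in DistributedDataParallel; the Apex convert_syncbn_model path is left unchanged and does not handle BatchNormAct2d.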