From 879df47c0a7e8108545a1ca1fcbfa88ca2714778 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:51:26 -0700 Subject: [PATCH 1/7] Support BatchNormAct2d for sync-bn use. Fix #1254 --- timm/models/__init__.py | 2 +- timm/models/layers/__init__.py | 2 +- timm/models/layers/norm_act.py | 90 ++++++++++++++++++++++++++++++++-- train.py | 19 +++---- 4 files changed, 97 insertions(+), 16 deletions(-) diff --git a/timm/models/__init__.py b/timm/models/__init__.py index 8cb6c70a..4f81683a 100644 --- a/timm/models/__init__.py +++ b/timm/models/__init__.py @@ -61,7 +61,7 @@ from .xcit import * from .factory import create_model, parse_model_name, safe_model_name from .helpers import load_checkpoint, resume_checkpoint, model_parameters from .layers import TestTimePoolHead, apply_test_time_pool -from .layers import convert_splitbn_model +from .layers import convert_splitbn_model, convert_sync_batchnorm from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\ is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 7e9e7b19..b1a64db3 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp from .non_local_attn import NonLocalAttn, BatNonLocalAttn from .norm import GroupNorm, LayerNorm2d -from .norm_act import BatchNormAct2d, GroupNormAct +from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm from .padding import get_padding, get_same_padding, pad_same from .patch_embed import PatchEmbed from .pool2d_same import AvgPool2dSame, create_pool2d diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 34c4fd64..261bdb0a 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -1,10 +1,15 @@ """ Normalization + Activation Layers """ -from typing import Union, List +from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F +try: + from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm + FULL_SYNC_BN = True +except ImportError: + FULL_SYNC_BN = False from .trace_utils import _assert from .create_act import get_act_layer @@ -18,10 +23,29 @@ class BatchNormAct2d(nn.BatchNorm2d): instead of composing it as a .bn member. 
""" def __init__( - self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, - apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None): - super(BatchNormAct2d, self).__init__( - num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) + self, + num_features, + eps=1e-5, + momentum=0.1, + affine=True, + track_running_stats=True, + apply_act=True, + act_layer=nn.ReLU, + inplace=True, + drop_layer=None, + device=None, + dtype=None + ): + try: + factory_kwargs = {'device': device, 'dtype': dtype} + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats, + **factory_kwargs + ) + except TypeError: + # NOTE for backwards compat with old PyTorch w/o factory device/dtype support + super(BatchNormAct2d, self).__init__( + num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) self.drop = drop_layer() if drop_layer is not None else nn.Identity() act_layer = get_act_layer(act_layer) # string -> nn.Module if act_layer is not None and apply_act: @@ -81,6 +105,62 @@ class BatchNormAct2d(nn.BatchNorm2d): return x +class SyncBatchNormAct(nn.SyncBatchNorm): + # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254) + # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers + # but ONLY when used in conjunction with the timm conversion function below. + # Do not create this module directly or use the PyTorch conversion function. + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine + if hasattr(self, "drop"): + x = self.drop(x) + if hasattr(self, "act"): + x = self.act(x) + return x + + +def convert_sync_batchnorm(module, process_group=None): + # convert both BatchNorm and BatchNormAct layers to Synchronized variants + module_output = module + if isinstance(module, torch.nn.modules.batchnorm._BatchNorm): + if isinstance(module, BatchNormAct2d): + # convert timm norm + act layer + module_output = SyncBatchNormAct( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group=process_group, + ) + # set act and drop attr from the original module + module_output.act = module.act + module_output.drop = module.drop + else: + # convert standard BatchNorm layers + module_output = torch.nn.SyncBatchNorm( + module.num_features, + module.eps, + module.momentum, + module.affine, + module.track_running_stats, + process_group, + ) + if module.affine: + with torch.no_grad(): + module_output.weight = module.weight + module_output.bias = module.bias + module_output.running_mean = module.running_mean + module_output.running_var = module.running_var + module_output.num_batches_tracked = module.num_batches_tracked + if hasattr(module, "qconfig"): + module_output.qconfig = module.qconfig + for name, child in module.named_children(): + module_output.add_module(name, convert_sync_batchnorm(child, process_group)) + del module + return module_output + + def _num_groups(num_channels, num_groups, group_size): if group_size: assert num_channels % group_size == 0 diff --git a/train.py b/train.py index 444be066..2a68e05e 100755 --- a/train.py +++ b/train.py @@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman) """ import 
argparse -import time -import yaml -import os import logging +import os +import time from collections import OrderedDict from contextlib import suppress from datetime import datetime @@ -26,14 +25,15 @@ from datetime import datetime import torch import torch.nn as nn import torchvision.utils +import yaml from torch.nn.parallel import DistributedDataParallel as NativeDDP -from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset -from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\ - convert_splitbn_model, model_parameters from timm import utils -from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\ +from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset +from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \ LabelSmoothingCrossEntropy +from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \ + convert_splitbn_model, convert_sync_batchnorm, model_parameters from timm.optim import create_optimizer_v2, optimizer_kwargs from timm.scheduler import create_scheduler from timm.utils import ApexScaler, NativeScaler @@ -440,10 +440,11 @@ def main(): if args.distributed and args.sync_bn: assert not args.split_bn if has_apex and use_amp == 'apex': - # Apex SyncBN preferred unless native amp is activated + # Apex SyncBN used with Apex AMP + # WARNING this won't currently work with models using BatchNormAct2d model = convert_syncbn_model(model) else: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = convert_sync_batchnorm(model) if args.local_rank == 0: _logger.info( 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using ' From 7d657d2ef45fc841f3a987ca0f18868686dbecf5 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:55:25 -0700 Subject: [PATCH 2/7] Improve resolve_pretrained_cfg behaviour when no cfg exists, warn instead of crash. 
Improve usability ex #1311 --- timm/models/helpers.py | 27 ++++++++++++++++-------- timm/models/inception_v3.py | 2 +- timm/models/vision_transformer.py | 2 +- timm/models/vision_transformer_relpos.py | 2 +- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 1276b68e..11630bb6 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -455,18 +455,27 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None): +def resolve_pretrained_cfg(variant: str, **kwargs): + pretrained_cfg = kwargs.pop('pretrained_cfg', None) if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed explicitly + # highest priority, pretrained_cfg available and passed in args return deepcopy(pretrained_cfg) - if kwargs and 'pretrained_cfg' in kwargs: - # next highest, pretrained_cfg in a kwargs dict, pop and return - pretrained_cfg = kwargs.pop('pretrained_cfg', {}) - if pretrained_cfg: - return deepcopy(pretrained_cfg) - # lookup pretrained cfg in model registry by variant + # fallback to looking up pretrained cfg in model registry by variant identifier pretrained_cfg = get_pretrained_cfg(variant) - assert pretrained_cfg + if not pretrained_cfg: + _logger.warning( + f"No pretrained configuration specified for {variant} model. Using a default." + f" Please add a config to the model pretrained_cfg registry or pass explicitly.") + pretrained_cfg = dict( + url='', + num_classes=1000, + input_size=(3, 224, 224), + pool_size=None, + crop_pct=.9, + interpolation='bicubic', + first_conv='', + classifier='', + ) return pretrained_cfg diff --git a/timm/models/inception_v3.py b/timm/models/inception_v3.py index e34de657..2c6e7eb7 100644 --- a/timm/models/inception_v3.py +++ b/timm/models/inception_v3.py @@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3): def _create_inception_v3(variant, pretrained=False, **kwargs): - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) aux_logits = kwargs.pop('aux_logits', False) if aux_logits: assert not kwargs.pop('features_only', False) diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 59fd7849..8551feae 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs): if kwargs.get('features_only', None): raise RuntimeError('features_only not implemented for Vision Transformer models.') - pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs) + pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None)) model = build_model_with_cfg( VisionTransformer, variant, pretrained, pretrained_cfg=pretrained_cfg, diff --git a/timm/models/vision_transformer_relpos.py b/timm/models/vision_transformer_relpos.py index 0c2ac376..0c9ac989 100644 --- a/timm/models/vision_transformer_relpos.py +++ b/timm/models/vision_transformer_relpos.py @@ -16,7 +16,7 @@ import torch.nn.functional as F from torch.utils.checkpoint import checkpoint from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD -from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply +from 
.helpers import build_model_with_cfg, named_apply from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple from .registry import register_model From 0da3c9ebbf483f76f51a366bebfd3d3589f74a0c Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:56:58 -0700 Subject: [PATCH 3/7] Remove SiLU layer in default args that breaks import on old old PyTorch --- timm/models/layers/evo_norm.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index b643302c..ea776207 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0): class EvoNorm2dS1(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module): class EvoNorm2dS1a(EvoNorm2dS1): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) @@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1): class EvoNorm2dS2(nn.Module): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + apply_act=True, act_layer=None, eps=1e-5, **_): super().__init__() + act_layer = act_layer or nn.SiLU self.apply_act = apply_act # apply activation (non-linearity) if act_layer is not None and apply_act: self.act = create_act_layer(act_layer) @@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module): class EvoNorm2dS2a(EvoNorm2dS2): def __init__( self, num_features, groups=32, group_size=None, - apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_): + apply_act=True, act_layer=None, eps=1e-3, **_): super().__init__( num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) From e27c16b8a02eaeae2d42d969f4b137e8e940dbe1 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:57:42 -0700 Subject: [PATCH 4/7] Remove unecessary code for synbn guard --- timm/models/layers/norm_act.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/timm/models/layers/norm_act.py b/timm/models/layers/norm_act.py index 261bdb0a..ea5b7883 100644 --- a/timm/models/layers/norm_act.py +++ b/timm/models/layers/norm_act.py @@ -5,11 +5,6 @@ from typing import Union, List, Optional, Any import torch from torch import nn as nn from torch.nn import functional as F -try: - from torch.nn.modules._functions import SyncBatchNorm as sync_batch_norm - FULL_SYNC_BN = True -except ImportError: - FULL_SYNC_BN = False from .trace_utils import _assert from .create_act import get_act_layer From 07d0c4ae963481a225ce8243f047408d96f13fab Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 14:58:15 -0700 Subject: [PATCH 5/7] Improve repr for DropPath module --- timm/models/layers/drop.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/timm/models/layers/drop.py b/timm/models/layers/drop.py index ae065277..1ab1c8f5 100644 --- a/timm/models/layers/drop.py +++ b/timm/models/layers/drop.py @@ -164,3 +164,6 @@ class 
DropPath(nn.Module): def forward(self, x): return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f'drop_prob={round(self.drop_prob,3):0.3f}' From a29fba307dc544ac1d4eeab0043c3aae8d5aa7c8 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 21:30:17 -0700 Subject: [PATCH 6/7] disable dist_bn when sync_bn active --- train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/train.py b/train.py index 2a68e05e..285981fd 100755 --- a/train.py +++ b/train.py @@ -438,6 +438,7 @@ def main(): # setup synchronized BatchNorm for distributed training if args.distributed and args.sync_bn: + args.dist_bn = '' # disable dist_bn when sync BN active assert not args.split_bn if has_apex and use_amp == 'apex': # Apex SyncBN used with Apex AMP From e6d7df40ecb9595b9e4f2f5b69f2633f81cee9ce Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Fri, 24 Jun 2022 21:32:44 -0700 Subject: [PATCH 7/7] no longer a point using kwargs for pretrain_cfg resolve, just pass explicit arg --- timm/models/helpers.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/timm/models/helpers.py b/timm/models/helpers.py index 11630bb6..fda84171 100644 --- a/timm/models/helpers.py +++ b/timm/models/helpers.py @@ -455,10 +455,9 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter): filter_kwargs(kwargs, names=kwargs_filter) -def resolve_pretrained_cfg(variant: str, **kwargs): - pretrained_cfg = kwargs.pop('pretrained_cfg', None) +def resolve_pretrained_cfg(variant: str, pretrained_cfg=None): if pretrained_cfg and isinstance(pretrained_cfg, dict): - # highest priority, pretrained_cfg available and passed in args + # highest priority, pretrained_cfg available and passed as arg return deepcopy(pretrained_cfg) # fallback to looking up pretrained cfg in model registry by variant identifier pretrained_cfg = get_pretrained_cfg(variant)
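
A minimal usage sketch of the convert_sync_batchnorm helper added in PATCH 1/7, mirroring the train.py change. The model name is illustrative and assumes the chosen architecture is built with BatchNormAct2d layers; this is a sketch, not part of the patches:

    import torch
    from timm.models import create_model, convert_sync_batchnorm

    # model name is illustrative; the helper works on any module tree
    model = create_model('efficientnet_b0', pretrained=False)

    # Replaces the stock torch.nn.SyncBatchNorm.convert_sync_batchnorm call:
    # plain BatchNorm layers become torch.nn.SyncBatchNorm, while BatchNormAct2d
    # becomes SyncBatchNormAct so the fused .act and .drop submodules are kept.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        model = convert_sync_batchnorm(model)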
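A sketch of the resolve_pretrained_cfg calling convention after PATCH 2/7 and PATCH 7/7: pretrained_cfg is now passed as an explicit argument (model entrypoints pop it out of kwargs), and an unregistered variant logs a warning and falls back to a default cfg instead of failing an assert. The variant name below is illustrative:

    from timm.models.helpers import resolve_pretrained_cfg

    # Highest priority: an explicitly passed pretrained_cfg dict is deep-copied and returned.
    cfg = resolve_pretrained_cfg(
        'vit_base_patch16_224', pretrained_cfg={'url': '', 'num_classes': 1000})

    # Otherwise the cfg is looked up in the registry by variant name; unknown
    # variants now warn and return a default cfg rather than crashing.
    cfg = resolve_pretrained_cfg('vit_base_patch16_224')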
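And a small sketch of the effect of the DropPath extra_repr added in PATCH 5/7 (expected output shown as a comment, assuming the module repr follows the standard nn.Module formatting):

    from timm.models.layers import DropPath

    dp = DropPath(0.1)
    print(dp)  # DropPath(drop_prob=0.100)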