Merge pull request #1317 from rwightman/fixes-syncbn_pretrain_cfg_resolve

Fix SyncBatchNorm for BatchNormAc2d, improve resolve_pretrained_cfg behaviour, other mix fixes.
pull/1322/head
Ross Wightman 3 years ago committed by GitHub
commit beef62e7ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -61,7 +61,7 @@ from .xcit import *
from .factory import create_model, parse_model_name, safe_model_name
from .helpers import load_checkpoint, resume_checkpoint, model_parameters
from .layers import TestTimePoolHead, apply_test_time_pool
from .layers import convert_splitbn_model
from .layers import convert_splitbn_model, convert_sync_batchnorm
from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
filter_kwargs(kwargs, names=kwargs_filter)
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
if pretrained_cfg and isinstance(pretrained_cfg, dict):
# highest priority, pretrained_cfg available and passed explicitly
# highest priority, pretrained_cfg available and passed as arg
return deepcopy(pretrained_cfg)
if kwargs and 'pretrained_cfg' in kwargs:
# next highest, pretrained_cfg in a kwargs dict, pop and return
pretrained_cfg = kwargs.pop('pretrained_cfg', {})
if pretrained_cfg:
return deepcopy(pretrained_cfg)
# lookup pretrained cfg in model registry by variant
# fallback to looking up pretrained cfg in model registry by variant identifier
pretrained_cfg = get_pretrained_cfg(variant)
assert pretrained_cfg
if not pretrained_cfg:
_logger.warning(
f"No pretrained configuration specified for {variant} model. Using a default."
f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
pretrained_cfg = dict(
url='',
num_classes=1000,
input_size=(3, 224, 224),
pool_size=None,
crop_pct=.9,
interpolation='bicubic',
first_conv='',
classifier='',
)
return pretrained_cfg

@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):
def _create_inception_v3(variant, pretrained=False, **kwargs):
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
aux_logits = kwargs.pop('aux_logits', False)
if aux_logits:
assert not kwargs.pop('features_only', False)

@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
from .non_local_attn import NonLocalAttn, BatNonLocalAttn
from .norm import GroupNorm, LayerNorm2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
from .padding import get_padding, get_same_padding, pad_same
from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d

@ -164,3 +164,6 @@ class DropPath(nn.Module):
def forward(self, x):
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
def extra_repr(self):
return f'drop_prob={round(self.drop_prob,3):0.3f}'

@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0):
class EvoNorm2dS1(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=None, eps=1e-5, **_):
super().__init__()
act_layer = act_layer or nn.SiLU
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module):
class EvoNorm2dS1a(EvoNorm2dS1):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
apply_act=True, act_layer=None, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1):
class EvoNorm2dS2(nn.Module):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
apply_act=True, act_layer=None, eps=1e-5, **_):
super().__init__()
act_layer = act_layer or nn.SiLU
self.apply_act = apply_act # apply activation (non-linearity)
if act_layer is not None and apply_act:
self.act = create_act_layer(act_layer)
@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module):
class EvoNorm2dS2a(EvoNorm2dS2):
def __init__(
self, num_features, groups=32, group_size=None,
apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
apply_act=True, act_layer=None, eps=1e-3, **_):
super().__init__(
num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)

@ -1,6 +1,6 @@
""" Normalization + Activation Layers
"""
from typing import Union, List
from typing import Union, List, Optional, Any
import torch
from torch import nn as nn
@ -18,10 +18,29 @@ class BatchNormAct2d(nn.BatchNorm2d):
instead of composing it as a .bn member.
"""
def __init__(
self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self,
num_features,
eps=1e-5,
momentum=0.1,
affine=True,
track_running_stats=True,
apply_act=True,
act_layer=nn.ReLU,
inplace=True,
drop_layer=None,
device=None,
dtype=None
):
try:
factory_kwargs = {'device': device, 'dtype': dtype}
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
**factory_kwargs
)
except TypeError:
# NOTE for backwards compat with old PyTorch w/o factory device/dtype support
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
@ -81,6 +100,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
return x
class SyncBatchNormAct(nn.SyncBatchNorm):
# Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
# This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
# but ONLY when used in conjunction with the timm conversion function below.
# Do not create this module directly or use the PyTorch conversion function.
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = super().forward(x) # SyncBN doesn't work with torchscript anyways, so this is fine
if hasattr(self, "drop"):
x = self.drop(x)
if hasattr(self, "act"):
x = self.act(x)
return x
def convert_sync_batchnorm(module, process_group=None):
# convert both BatchNorm and BatchNormAct layers to Synchronized variants
module_output = module
if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
if isinstance(module, BatchNormAct2d):
# convert timm norm + act layer
module_output = SyncBatchNormAct(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group=process_group,
)
# set act and drop attr from the original module
module_output.act = module.act
module_output.drop = module.drop
else:
# convert standard BatchNorm layers
module_output = torch.nn.SyncBatchNorm(
module.num_features,
module.eps,
module.momentum,
module.affine,
module.track_running_stats,
process_group,
)
if module.affine:
with torch.no_grad():
module_output.weight = module.weight
module_output.bias = module.bias
module_output.running_mean = module.running_mean
module_output.running_var = module.running_var
module_output.num_batches_tracked = module.num_batches_tracked
if hasattr(module, "qconfig"):
module_output.qconfig = module.qconfig
for name, child in module.named_children():
module_output.add_module(name, convert_sync_batchnorm(child, process_group))
del module
return module_output
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0

@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
if kwargs.get('features_only', None):
raise RuntimeError('features_only not implemented for Vision Transformer models.')
pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
model = build_model_with_cfg(
VisionTransformer, variant, pretrained,
pretrained_cfg=pretrained_cfg,

@ -16,7 +16,7 @@ import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
from .helpers import build_model_with_cfg, named_apply
from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
from .registry import register_model

@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
"""
import argparse
import time
import yaml
import os
import logging
import os
import time
from collections import OrderedDict
from contextlib import suppress
from datetime import datetime
@ -26,14 +25,15 @@ from datetime import datetime
import torch
import torch.nn as nn
import torchvision.utils
import yaml
from torch.nn.parallel import DistributedDataParallel as NativeDDP
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
convert_splitbn_model, model_parameters
from timm import utils
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
LabelSmoothingCrossEntropy
from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
convert_splitbn_model, convert_sync_batchnorm, model_parameters
from timm.optim import create_optimizer_v2, optimizer_kwargs
from timm.scheduler import create_scheduler
from timm.utils import ApexScaler, NativeScaler
@ -438,12 +438,14 @@ def main():
# setup synchronized BatchNorm for distributed training
if args.distributed and args.sync_bn:
args.dist_bn = '' # disable dist_bn when sync BN active
assert not args.split_bn
if has_apex and use_amp == 'apex':
# Apex SyncBN preferred unless native amp is activated
# Apex SyncBN used with Apex AMP
# WARNING this won't currently work with models using BatchNormAct2d
model = convert_syncbn_model(model)
else:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = convert_sync_batchnorm(model)
if args.local_rank == 0:
_logger.info(
'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '

Loading…
Cancel
Save