Merge pull request #1317 from rwightman/fixes-syncbn_pretrain_cfg_resolve

Fix SyncBatchNorm for BatchNormAct2d, improve resolve_pretrained_cfg behaviour, other misc fixes.
Ross Wightman 3 years ago committed by GitHub
commit beef62e7ab
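
For context, a rough usage sketch of the new converter in a distributed training script (not part of this diff; the model name and DDP boilerplate are illustrative, and the process group is assumed to be initialized by the launcher):

import torch
from timm import create_model
from timm.models import convert_sync_batchnorm

# assumes torch.distributed has already been initialized (e.g. by torchrun)
model = create_model('resnet50').cuda()
# timm's converter keeps the fused act/drop of BatchNormAct2d layers, which
# torch.nn.SyncBatchNorm.convert_sync_batchnorm would silently drop
model = convert_sync_batchnorm(model)
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[torch.cuda.current_device()])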

@@ -61,7 +61,7 @@ from .xcit import *
 from .factory import create_model, parse_model_name, safe_model_name
 from .helpers import load_checkpoint, resume_checkpoint, model_parameters
 from .layers import TestTimePoolHead, apply_test_time_pool
-from .layers import convert_splitbn_model
+from .layers import convert_splitbn_model, convert_sync_batchnorm
 from .layers import is_scriptable, is_exportable, set_scriptable, set_exportable, is_no_jit, set_no_jit
 from .registry import register_model, model_entrypoint, list_models, is_model, list_modules, is_model_in_modules,\
     is_model_pretrained, get_pretrained_cfg, has_pretrained_cfg_key, is_pretrained_cfg_key, get_pretrained_cfg_value

@@ -455,18 +455,26 @@ def update_pretrained_cfg_and_kwargs(pretrained_cfg, kwargs, kwargs_filter):
     filter_kwargs(kwargs, names=kwargs_filter)


-def resolve_pretrained_cfg(variant: str, pretrained_cfg=None, kwargs=None):
+def resolve_pretrained_cfg(variant: str, pretrained_cfg=None):
     if pretrained_cfg and isinstance(pretrained_cfg, dict):
-        # highest priority, pretrained_cfg available and passed explicitly
+        # highest priority, pretrained_cfg available and passed as arg
         return deepcopy(pretrained_cfg)
-    if kwargs and 'pretrained_cfg' in kwargs:
-        # next highest, pretrained_cfg in a kwargs dict, pop and return
-        pretrained_cfg = kwargs.pop('pretrained_cfg', {})
-        if pretrained_cfg:
-            return deepcopy(pretrained_cfg)
-    # lookup pretrained cfg in model registry by variant
+    # fallback to looking up pretrained cfg in model registry by variant identifier
     pretrained_cfg = get_pretrained_cfg(variant)
-    assert pretrained_cfg
+    if not pretrained_cfg:
+        _logger.warning(
+            f"No pretrained configuration specified for {variant} model. Using a default."
+            f" Please add a config to the model pretrained_cfg registry or pass explicitly.")
+        pretrained_cfg = dict(
+            url='',
+            num_classes=1000,
+            input_size=(3, 224, 224),
+            pool_size=None,
+            crop_pct=.9,
+            interpolation='bicubic',
+            first_conv='',
+            classifier='',
+        )
     return pretrained_cfg
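
A short sketch of the resolved behaviour after this change, assuming resolve_pretrained_cfg is imported from timm.models.helpers and 'vit_base_patch16_224' is a registered variant:

from timm.models.helpers import resolve_pretrained_cfg

# highest priority: an explicit dict passed by the caller is returned as a copy
cfg = resolve_pretrained_cfg('vit_base_patch16_224', pretrained_cfg={'num_classes': 21843})

# otherwise the cfg is looked up in the registry by variant name; if nothing is
# registered, the function now warns and falls back to a default cfg instead of asserting
cfg = resolve_pretrained_cfg('vit_base_patch16_224')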

@@ -428,7 +428,7 @@ class InceptionV3Aux(InceptionV3):

 def _create_inception_v3(variant, pretrained=False, **kwargs):
-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     aux_logits = kwargs.pop('aux_logits', False)
     if aux_logits:
         assert not kwargs.pop('features_only', False)

@@ -26,7 +26,7 @@ from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, ConvMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
 from .norm import GroupNorm, LayerNorm2d
-from .norm_act import BatchNormAct2d, GroupNormAct
+from .norm_act import BatchNormAct2d, GroupNormAct, convert_sync_batchnorm
 from .padding import get_padding, get_same_padding, pad_same
 from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d

@@ -164,3 +164,6 @@ class DropPath(nn.Module):
     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f'drop_prob={round(self.drop_prob,3):0.3f}'
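
A tiny illustration of what the new extra_repr buys (assuming DropPath is imported from timm.models.layers):

from timm.models.layers import DropPath

# the drop probability now shows up when printing a model, roughly:
# DropPath(drop_prob=0.100)
print(DropPath(0.1))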

@@ -256,8 +256,9 @@ class EvoNorm2dS0a(EvoNorm2dS0):
 class EvoNorm2dS1(nn.Module):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+            apply_act=True, act_layer=None, eps=1e-5, **_):
         super().__init__()
+        act_layer = act_layer or nn.SiLU
         self.apply_act = apply_act  # apply activation (non-linearity)
         if act_layer is not None and apply_act:
             self.act = create_act_layer(act_layer)
@@ -290,7 +291,7 @@ class EvoNorm2dS1(nn.Module):
 class EvoNorm2dS1a(EvoNorm2dS1):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+            apply_act=True, act_layer=None, eps=1e-3, **_):
         super().__init__(
             num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
@@ -305,8 +306,9 @@ class EvoNorm2dS1a(EvoNorm2dS1):
 class EvoNorm2dS2(nn.Module):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_):
+            apply_act=True, act_layer=None, eps=1e-5, **_):
         super().__init__()
+        act_layer = act_layer or nn.SiLU
         self.apply_act = apply_act  # apply activation (non-linearity)
         if act_layer is not None and apply_act:
             self.act = create_act_layer(act_layer)
@@ -338,7 +340,7 @@ class EvoNorm2dS2(nn.Module):
 class EvoNorm2dS2a(EvoNorm2dS2):
     def __init__(
             self, num_features, groups=32, group_size=None,
-            apply_act=True, act_layer=nn.SiLU, eps=1e-3, **_):
+            apply_act=True, act_layer=None, eps=1e-3, **_):
         super().__init__(
             num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps)
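
A small sketch of what the act_layer default change means for callers (the direct evonorm import path used here is an assumption):

import torch.nn as nn
from timm.models.layers.evonorm import EvoNorm2dS1

# act_layer=None now means "use the default SiLU", so existing callers see no change
norm = EvoNorm2dS1(64)
assert isinstance(norm.act, nn.SiLU)

# an explicit activation still overrides the default
norm_relu = EvoNorm2dS1(64, act_layer=nn.ReLU)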

@@ -1,6 +1,6 @@
 """ Normalization + Activation Layers
 """
-from typing import Union, List
+from typing import Union, List, Optional, Any

 import torch
 from torch import nn as nn
@@ -18,8 +18,27 @@ class BatchNormAct2d(nn.BatchNorm2d):
     instead of composing it as a .bn member.
     """
     def __init__(
-            self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
-            apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
-        super(BatchNormAct2d, self).__init__(
-            num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
+            self,
+            num_features,
+            eps=1e-5,
+            momentum=0.1,
+            affine=True,
+            track_running_stats=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            inplace=True,
+            drop_layer=None,
+            device=None,
+            dtype=None
+    ):
+        try:
+            factory_kwargs = {'device': device, 'dtype': dtype}
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats,
+                **factory_kwargs
+            )
+        except TypeError:
+            # NOTE for backwards compat with old PyTorch w/o factory device/dtype support
+            super(BatchNormAct2d, self).__init__(
+                num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
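
A brief sketch of the effect of the new factory kwargs (a minimal example, assuming BatchNormAct2d is imported from timm.models.layers):

import torch
from timm.models.layers import BatchNormAct2d

# on recent PyTorch the device/dtype kwargs are forwarded to nn.BatchNorm2d;
# on older releases the TypeError fallback simply ignores them
bn = BatchNormAct2d(64, device='cpu', dtype=torch.float32)
y = bn(torch.randn(2, 64, 8, 8))  # batch norm, then (identity) drop, then ReLU
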
@@ -81,6 +100,62 @@ class BatchNormAct2d(nn.BatchNorm2d):
         return x


+class SyncBatchNormAct(nn.SyncBatchNorm):
+    # Thanks to Selim Seferbekov (https://github.com/rwightman/pytorch-image-models/issues/1254)
+    # This is a quick workaround to support SyncBatchNorm for timm BatchNormAct2d layers
+    # but ONLY when used in conjunction with the timm conversion function below.
+    # Do not create this module directly or use the PyTorch conversion function.
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = super().forward(x)  # SyncBN doesn't work with torchscript anyways, so this is fine
+        if hasattr(self, "drop"):
+            x = self.drop(x)
+        if hasattr(self, "act"):
+            x = self.act(x)
+        return x
+
+
+def convert_sync_batchnorm(module, process_group=None):
+    # convert both BatchNorm and BatchNormAct layers to Synchronized variants
+    module_output = module
+    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
+        if isinstance(module, BatchNormAct2d):
+            # convert timm norm + act layer
+            module_output = SyncBatchNormAct(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group=process_group,
+            )
+            # set act and drop attr from the original module
+            module_output.act = module.act
+            module_output.drop = module.drop
+        else:
+            # convert standard BatchNorm layers
+            module_output = torch.nn.SyncBatchNorm(
+                module.num_features,
+                module.eps,
+                module.momentum,
+                module.affine,
+                module.track_running_stats,
+                process_group,
+            )
+        if module.affine:
+            with torch.no_grad():
+                module_output.weight = module.weight
+                module_output.bias = module.bias
+        module_output.running_mean = module.running_mean
+        module_output.running_var = module.running_var
+        module_output.num_batches_tracked = module.num_batches_tracked
+        if hasattr(module, "qconfig"):
+            module_output.qconfig = module.qconfig
+    for name, child in module.named_children():
+        module_output.add_module(name, convert_sync_batchnorm(child, process_group))
+    del module
+    return module_output
+
+
 def _num_groups(num_channels, num_groups, group_size):
     if group_size:
         assert num_channels % group_size == 0
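
A layer-level sketch of what the converter does, assuming both helpers are importable from timm.models.layers; the forward pass of any SyncBatchNorm still requires an initialized distributed process group, but the conversion itself does not:

import torch.nn as nn
from timm.models.layers import BatchNormAct2d, convert_sync_batchnorm

net = nn.Sequential(
    nn.Conv2d(3, 32, 3),
    BatchNormAct2d(32),   # fused BN + act, would lose its act with the stock converter
    nn.Conv2d(32, 32, 3),
    nn.BatchNorm2d(32),   # plain BN, handled like torch's own converter would
)
net = convert_sync_batchnorm(net)
print(type(net[1]).__name__)  # SyncBatchNormAct, with .act/.drop copied over
print(type(net[3]).__name__)  # SyncBatchNorm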

@@ -633,7 +633,7 @@ def _create_vision_transformer(variant, pretrained=False, **kwargs):
     if kwargs.get('features_only', None):
         raise RuntimeError('features_only not implemented for Vision Transformer models.')

-    pretrained_cfg = resolve_pretrained_cfg(variant, kwargs=kwargs)
+    pretrained_cfg = resolve_pretrained_cfg(variant, pretrained_cfg=kwargs.pop('pretrained_cfg', None))
     model = build_model_with_cfg(
         VisionTransformer, variant, pretrained,
         pretrained_cfg=pretrained_cfg,

@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from torch.utils.checkpoint import checkpoint

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
-from .helpers import build_model_with_cfg, resolve_pretrained_cfg, named_apply
+from .helpers import build_model_with_cfg, named_apply
 from .layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, to_2tuple
 from .registry import register_model

@@ -15,10 +15,9 @@ NVIDIA CUDA specific speedups adopted from NVIDIA Apex examples
 Hacked together by / Copyright 2020 Ross Wightman (https://github.com/rwightman)
 """
 import argparse
-import time
-import yaml
-import os
 import logging
+import os
+import time
 from collections import OrderedDict
 from contextlib import suppress
 from datetime import datetime
@@ -26,14 +25,15 @@ from datetime import datetime
 import torch
 import torch.nn as nn
 import torchvision.utils
+import yaml
 from torch.nn.parallel import DistributedDataParallel as NativeDDP

-from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
-from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint,\
-    convert_splitbn_model, model_parameters
 from timm import utils
-from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy,\
+from timm.data import create_dataset, create_loader, resolve_data_config, Mixup, FastCollateMixup, AugMixDataset
+from timm.loss import JsdCrossEntropy, SoftTargetCrossEntropy, BinaryCrossEntropy, \
     LabelSmoothingCrossEntropy
+from timm.models import create_model, safe_model_name, resume_checkpoint, load_checkpoint, \
+    convert_splitbn_model, convert_sync_batchnorm, model_parameters
 from timm.optim import create_optimizer_v2, optimizer_kwargs
 from timm.scheduler import create_scheduler
 from timm.utils import ApexScaler, NativeScaler
@@ -438,12 +438,14 @@ def main():
     # setup synchronized BatchNorm for distributed training
     if args.distributed and args.sync_bn:
+        args.dist_bn = ''  # disable dist_bn when sync BN active
         assert not args.split_bn
         if has_apex and use_amp == 'apex':
-            # Apex SyncBN preferred unless native amp is activated
+            # Apex SyncBN used with Apex AMP
+            # WARNING this won't currently work with models using BatchNormAct2d
             model = convert_syncbn_model(model)
         else:
-            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+            model = convert_sync_batchnorm(model)
         if args.local_rank == 0:
             _logger.info(
                 'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
