Significant norm update

* ConvBnAct layer renamed -> ConvNormAct, with ConvNormActAa for the anti-aliased variant
* Significant update to EfficientNet and MobileNetV3 archs to support NormAct layers and grouped conv (as an alternative to depthwise)
* Update RegNet to add Z variant
* Add Pre variant of XceptionAligned that works with NormAct layers
* EvoNorm matches bits_and_tpu branch for merge
Branch: pull/1239/head
Author: Ross Wightman · 2 years ago
Parent: 57fca2b5b2
Commit: d829858550
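
For orientation before the diff, a minimal usage sketch of the renamed layers (illustrative only, assuming the import paths touched in this commit; ConvBnAct is kept as a backwards-compatibility alias and the norm submodule keeps its `.bn` name):

import torch
import torch.nn as nn
from timm.models.layers import ConvNormAct, ConvNormActAa, BlurPool2d

# conv -> norm -> act, formerly ConvBnAct (still importable under the old name)
conv = ConvNormAct(32, 64, kernel_size=3, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU)

# anti-aliased variant: the aa_layer is applied on the stride == 2 path
conv_aa = ConvNormActAa(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)

x = torch.randn(1, 32, 56, 56)
print(conv(x).shape)     # 64 channels, same spatial size
print(conv_aa(x).shape)  # 64 channels, spatially downsampled by 2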

timm/models/byobnet.py

@@ -34,8 +34,8 @@ import torch.nn as nn
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg, named_apply
-from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
-    create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
+from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
+    create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
     EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
 from .registry import register_model
@@ -975,7 +975,7 @@ def num_groups(group_size, channels):
 @dataclass
 class LayerFn:
-    conv_norm_act: Callable = ConvBnAct
+    conv_norm_act: Callable = ConvNormAct
     norm_act: Callable = BatchNormAct2d
     act: Callable = nn.ReLU
     attn: Optional[Callable] = None
@@ -1032,7 +1032,7 @@ class BasicBlock(nn.Module):
         self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
         self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
         self.conv2_kxk = layers.conv_norm_act(
-            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
+            mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False)
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
@@ -1073,11 +1073,9 @@ class BottleneckBlock(nn.Module):
         self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
         self.conv2_kxk = layers.conv_norm_act(
-            mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block)
+            mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block)
         if extra_conv:
-            self.conv2b_kxk = layers.conv_norm_act(
-                mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block)
+            self.conv2b_kxk = layers.conv_norm_act(mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups)
         else:
             self.conv2b_kxk = nn.Identity()
         self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
@@ -1134,7 +1132,7 @@ class DarkBlock(nn.Module):
         self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
         self.conv2_kxk = layers.conv_norm_act(
             mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False)
+            groups=groups, drop_layer=drop_block, apply_act=False)
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
         self.act = nn.Identity() if linear_out else layers.act(inplace=True)
@@ -1181,8 +1179,7 @@ class EdgeBlock(nn.Module):
             apply_act=False, layers=layers)
         self.conv1_kxk = layers.conv_norm_act(
-            in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block)
+            in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block)
         self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
         self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
         self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
@@ -1226,7 +1223,7 @@ class RepVggBlock(nn.Module):
         self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
         self.conv_kxk = layers.conv_norm_act(
             in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
-            groups=groups, drop_block=drop_block, apply_act=False)
+            groups=groups, drop_layer=drop_block, apply_act=False)
         self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
         self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
@@ -1273,7 +1270,7 @@ class SelfAttnBlock(nn.Module):
         if extra_conv:
             self.conv2_kxk = layers.conv_norm_act(
                 mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
-                groups=groups, drop_block=drop_block)
+                groups=groups, drop_layer=drop_block)
             stride = 1  # striding done via conv if enabled
         else:
             self.conv2_kxk = nn.Identity()
@@ -1520,8 +1517,8 @@ def create_byob_stages(
 def get_layer_fns(cfg: ByoModelCfg):
     act = get_act_layer(cfg.act_layer)
-    norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
-    conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
+    norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
+    conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
     attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
     self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
     layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)

timm/models/cspnet.py

@@ -14,11 +14,10 @@ Hacked together by / Copyright 2020 Ross Wightman
 """
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer
+from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer
 from .registry import register_model
@@ -130,7 +129,7 @@ model_cfgs = dict(
 def create_stem(
         in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
-        act_layer=None, norm_layer=None, aa_layer=None):
+        act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
     stem = nn.Sequential()
     if not isinstance(out_chs, (tuple, list)):
         out_chs = [out_chs]
@@ -138,7 +137,7 @@ def create_stem(
     in_c = in_chans
     for i, out_c in enumerate(out_chs):
         conv_name = f'conv{i + 1}'
-        stem.add_module(conv_name, ConvBnAct(
+        stem.add_module(conv_name, ConvNormAct(
             in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
             act_layer=act_layer, norm_layer=norm_layer))
         in_c = out_c
@@ -161,12 +160,14 @@ class ResBottleneck(nn.Module):
             attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
         super(ResBottleneck, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
-        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
-        self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
-        self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
+        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
+        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
+        self.conv2 = ConvNormActAa(
+            mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
+            aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
         self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
-        self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
+        self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
         self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
         self.drop_path = drop_path
         self.act3 = act_layer(inplace=True)
@@ -201,9 +202,11 @@ class DarkBlock(nn.Module):
             drop_block=None, drop_path=None):
         super(DarkBlock, self).__init__()
         mid_chs = int(round(out_chs * bottle_ratio))
-        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
-        self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
-        self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
+        ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
+        self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
+        self.conv2 = ConvNormActAa(
+            mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
+            aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
         self.attn = create_attn(attn_layer, channels=out_chs)
         self.drop_path = drop_path
@@ -235,7 +238,7 @@ class CrossStage(nn.Module):
         conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
         if stride != 1 or first_dilation != dilation:
-            self.conv_down = ConvBnAct(
+            self.conv_down = ConvNormActAa(
                 in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
                 aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
             prev_chs = down_chs
@@ -246,7 +249,7 @@ class CrossStage(nn.Module):
         # FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
         # there is also special case for the first stage for some of the model that results in uneven split
         # across the two paths. I did it this way for simplicity for now.
-        self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
+        self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
         prev_chs = exp_chs // 2  # output of conv_exp is always split in two
         self.blocks = nn.Sequential()
@@ -257,8 +260,8 @@ class CrossStage(nn.Module):
             prev_chs = block_out_chs
         # transition convs
-        self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
-        self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
+        self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
+        self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
     def forward(self, x):
         if self.conv_down is not None:
@@ -280,7 +283,7 @@ class DarkStage(nn.Module):
         super(DarkStage, self).__init__()
         first_dilation = first_dilation or dilation
-        self.conv_down = ConvBnAct(
+        self.conv_down = ConvNormActAa(
             in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
             act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'),
             aa_layer=block_kwargs.get('aa_layer', None))
@@ -437,7 +440,7 @@ def cspresnext50(pretrained=False, **kwargs):
 @register_model
 def cspresnext50_iabn(pretrained=False, **kwargs):
-    norm_layer = get_norm_act_layer('iabn')
+    norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
     return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
@@ -448,7 +451,7 @@ def cspdarknet53(pretrained=False, **kwargs):
 @register_model
 def cspdarknet53_iabn(pretrained=False, **kwargs):
-    norm_layer = get_norm_act_layer('iabn')
+    norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
    return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)

timm/models/densenet.py

@@ -14,7 +14,7 @@ from torch.jit.annotations import List
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
+from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
 from .registry import register_model
 __all__ = ['DenseNet']
@@ -370,7 +370,7 @@ def densenet264d_iabn(pretrained=False, **kwargs):
     r"""Densenet-264 model with deep stem and Inplace-ABN
     """
     def norm_act_fn(num_features, **kwargs):
-        return create_norm_act('iabn', num_features, **kwargs)
+        return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
     model = _create_densenet(
         'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
         norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)

timm/models/dpn.py

@@ -16,7 +16,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
+from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
 from .registry import register_model
 __all__ = ['DPN']
@@ -180,7 +180,7 @@ class DPN(nn.Module):
         blocks = OrderedDict()
         # conv1
-        blocks['conv1_1'] = ConvBnAct(
+        blocks['conv1_1'] = ConvNormAct(
             in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
         blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]

timm/models/efficientnet.py

@@ -45,7 +45,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
 from .features import FeatureInfo, FeatureHooks
 from .helpers import build_model_with_cfg, default_cfg_for_features
-from .layers import create_conv2d, create_classifier
+from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
 from .registry import register_model
 __all__ = ['EfficientNet', 'EfficientNetFeatures']
@@ -117,6 +117,20 @@ default_cfgs = {
     'efficientnet_l2': _cfg(
         url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
+    # FIXME experimental
+    'efficientnet_b0_gn': _cfg(
+        url=''),
+    'efficientnet_b0_g8': _cfg(
+        url=''),
+    'efficientnet_b0_g16_evos': _cfg(
+        url=''),
+    'efficientnet_b3_gn': _cfg(
+        url='',
+        input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
+    'efficientnet_b3_g8_gn': _cfg(
+        url='',
+        input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
     'efficientnet_es': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
     'efficientnet_em': _cfg(
@@ -431,6 +445,7 @@ class EfficientNet(nn.Module):
         super(EfficientNet, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.num_classes = num_classes
         self.num_features = num_features
@@ -440,8 +455,7 @@ class EfficientNet(nn.Module):
         if not fix_stem:
             stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
-        self.bn1 = norm_layer(stem_size)
-        self.act1 = act_layer(inplace=True)
+        self.bn1 = norm_act_layer(stem_size, inplace=True)
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
@@ -453,17 +467,16 @@ class EfficientNet(nn.Module):
         # Head + Pooling
         self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
-        self.bn2 = norm_layer(self.num_features)
-        self.act2 = act_layer(inplace=True)
+        self.bn2 = norm_act_layer(self.num_features, inplace=True)
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)
         efficientnet_init_weights(self)
     def as_sequential(self):
-        layers = [self.conv_stem, self.bn1, self.act1]
+        layers = [self.conv_stem, self.bn1]
         layers.extend(self.blocks)
-        layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
+        layers.extend([self.conv_head, self.bn2, self.global_pool])
         layers.extend([nn.Dropout(self.drop_rate), self.classifier])
         return nn.Sequential(*layers)
@@ -478,11 +491,9 @@ class EfficientNet(nn.Module):
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)
-        x = self.act1(x)
         x = self.blocks(x)
         x = self.conv_head(x)
         x = self.bn2(x)
-        x = self.act2(x)
         return x
     def forward(self, x):
@@ -506,6 +517,7 @@ class EfficientNetFeatures(nn.Module):
         super(EfficientNetFeatures, self).__init__()
         act_layer = act_layer or nn.ReLU
         norm_layer = norm_layer or nn.BatchNorm2d
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
         self.drop_rate = drop_rate
@@ -513,8 +525,7 @@ class EfficientNetFeatures(nn.Module):
         if not fix_stem:
             stem_size = round_chs_fn(stem_size)
         self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
-        self.bn1 = norm_layer(stem_size)
-        self.act1 = act_layer(inplace=True)
+        self.bn1 = norm_act_layer(stem_size, inplace=True)
         # Middle stages (IR/ER/DS Blocks)
         builder = EfficientNetBuilder(
@@ -536,7 +547,6 @@ class EfficientNetFeatures(nn.Module):
     def forward(self, x) -> List[torch.Tensor]:
         x = self.conv_stem(x)
         x = self.bn1(x)
-        x = self.act1(x)
         if self.feature_hooks is None:
             features = []
             if 0 in self._stage_out_idx:
@@ -767,7 +777,9 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
     return model
-def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_efficientnet(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
+        group_size=None, pretrained=False, **kwargs):
     """Creates an EfficientNet model.
     Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -800,9 +812,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
         ['ir_r4_k5_s2_e6_c192_se0.25'],
         ['ir_r1_k3_s1_e6_c320_se0.25'],
     ]
-    round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
+    round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
     model_kwargs = dict(
-        block_args=decode_arch_def(arch_def, depth_multiplier),
+        block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
         num_features=round_chs_fn(1280),
         stem_size=32,
         round_chs_fn=round_chs_fn,
@@ -814,7 +826,8 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
     return model
-def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
+def _gen_efficientnet_edge(
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
     """ Creates an EfficientNet-EdgeTPU model
     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
@@ -832,7 +845,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
     ]
     round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
-        block_args=decode_arch_def(arch_def, depth_multiplier),
+        block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
         num_features=round_chs_fn(1280),
         stem_size=32,
         round_chs_fn=round_chs_fn,
@@ -946,7 +959,7 @@ def _gen_efficientnetv2_base(
 def _gen_efficientnetv2_s(
-        variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
+        variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs):
     """ Creates an EfficientNet-V2 Small model
     Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -972,7 +985,7 @@ def _gen_efficientnetv2_s(
     round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
     model_kwargs = dict(
-        block_args=decode_arch_def(arch_def, depth_multiplier),
+        block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
         num_features=round_chs_fn(num_features),
         stem_size=24,
         round_chs_fn=round_chs_fn,
@@ -1366,6 +1379,52 @@ def efficientnet_l2(pretrained=False, **kwargs):
     return model
+# FIXME experimental group cong / GroupNorm / EvoNorm experiments
+@register_model
+def efficientnet_b0_gn(pretrained=False, **kwargs):
+    """ EfficientNet-B0 + GroupNorm"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_gn', norm_layer=partial(GroupNormAct, group_size=8), pretrained=pretrained, **kwargs)
+    return model
+
+@register_model
+def efficientnet_b0_g8(pretrained=False, **kwargs):
+    """ EfficientNet-B0 w/ group conv + BN"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs)
+    return model
+
+@register_model
+def efficientnet_b0_g16_evos(pretrained=False, **kwargs):
+    """ EfficientNet-B0 w/ group 16 conv + EvoNorm"""
+    model = _gen_efficientnet(
+        'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
+        norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs)
+    return model
+
+@register_model
+def efficientnet_b3_gn(pretrained=False, **kwargs):
+    """ EfficientNet-B3 w/ GroupNorm """
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b3_gn', channel_multiplier=1.2, depth_multiplier=1.4, channel_divisor=16,
+        norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
+    return model
+
+@register_model
+def efficientnet_b3_g8_gn(pretrained=False, **kwargs):
+    """ EfficientNet-B3 w/ grouped conv + BN"""
+    # NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
+    model = _gen_efficientnet(
+        'efficientnet_b3_g8_gn', channel_multiplier=1.2, depth_multiplier=1.4, group_size=8, channel_divisor=16,
+        norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
+    return model
+
 @register_model
 def efficientnet_es(pretrained=False, **kwargs):
     """ EfficientNet-Edge Small. """
@@ -1373,6 +1432,7 @@ def efficientnet_es(pretrained=False, **kwargs):
         'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
     return model
 @register_model
 def efficientnet_es_pruned(pretrained=False, **kwargs):
     """ EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""

timm/models/efficientnet_blocks.py

@@ -2,18 +2,31 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
+import math
 import torch
 import torch.nn as nn
 from torch.nn import functional as F
-from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
-from .layers.activations import sigmoid
+from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
 __all__ = [
     'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
+
+def num_groups(group_size, channels):
+    if not group_size:  # 0 or None
+        return 1  # normal conv with 1 group
+    else:
+        # NOTE group_size == 1 -> depthwise conv
+        #assert channels % group_size == 0
+        if channels % group_size != 0:
+            num_groups = math.floor(channels / group_size)
+            print(channels, group_size, num_groups)
+            return int(num_groups)
+        return channels // group_size
+
 class SqueezeExcite(nn.Module):
     """ Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
@@ -51,31 +64,30 @@ class ConvBnAct(nn.Module):
     """ Conv + Norm Layer + Activation w/ optional skip connection
     """
     def __init__(
-            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
+            self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='',
             skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
         super(ConvBnAct, self).__init__()
-        self.has_residual = skip and stride == 1 and in_chs == out_chs
-        self.drop_path_rate = drop_path_rate
-        self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(out_chs)
-        self.act1 = act_layer(inplace=True)
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+        groups = num_groups(group_size, in_chs)
+        self.has_skip = skip and stride == 1 and in_chs == out_chs
+        self.conv = create_conv2d(
+            in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
+        self.bn1 = norm_act_layer(out_chs, inplace=True)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
     def feature_info(self, location):
         if location == 'expansion':  # output of conv after act, same as block coutput
-            info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
+            return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
         else:  # location == 'bottleneck', block output
-            info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
-        return info
+            return dict(module='', hook_type='', num_chs=self.conv.out_channels)
     def forward(self, x):
         shortcut = x
         x = self.conv(x)
         x = self.bn1(x)
-        x = self.act1(x)
-        if self.has_residual:
-            if self.drop_path_rate > 0.:
-                x = drop_path(x, self.drop_path_rate, self.training)
-            x += shortcut
+        if self.has_skip:
+            x = x + self.drop_path(shortcut)
         return x
@@ -85,50 +97,41 @@ class DepthwiseSeparableConv(nn.Module):
     (factor of 1.0). This is an alternative to having a IR with an optional first pw conv.
     """
     def __init__(
-            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
             noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
             se_layer=None, drop_path_rate=0.):
         super(DepthwiseSeparableConv, self).__init__()
-        self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+        groups = num_groups(group_size, in_chs)
+        self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
         self.has_pw_act = pw_act  # activation after point-wise conv
-        self.drop_path_rate = drop_path_rate
         self.conv_dw = create_conv2d(
-            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
-        self.bn1 = norm_layer(in_chs)
-        self.act1 = act_layer(inplace=True)
+            in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
+        self.bn1 = norm_act_layer(in_chs, inplace=True)
         # Squeeze-and-excitation
         self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
         self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs)
-        self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
+        self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
     def feature_info(self, location):
         if location == 'expansion':  # after SE, input to PW
-            info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
+            return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
         else:  # location == 'bottleneck', block output
-            info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
-        return info
+            return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
     def forward(self, x):
         shortcut = x
         x = self.conv_dw(x)
         x = self.bn1(x)
-        x = self.act1(x)
         x = self.se(x)
         x = self.conv_pw(x)
         x = self.bn2(x)
-        x = self.act2(x)
-        if self.has_residual:
-            if self.drop_path_rate > 0.:
-                x = drop_path(x, self.drop_path_rate, self.training)
-            x += shortcut
+        if self.has_skip:
+            x = x + self.drop_path(shortcut)
         return x
@@ -143,66 +146,51 @@ class InvertedResidual(nn.Module):
     """
     def __init__(
-            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
             noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
             norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
         super(InvertedResidual, self).__init__()
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         conv_kwargs = conv_kwargs or {}
         mid_chs = make_divisible(in_chs * exp_ratio)
-        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
-        self.drop_path_rate = drop_path_rate
+        groups = num_groups(group_size, mid_chs)
+        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
         # Point-wise expansion
         self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn1 = norm_layer(mid_chs)
-        self.act1 = act_layer(inplace=True)
+        self.bn1 = norm_act_layer(mid_chs, inplace=True)
         # Depth-wise convolution
         self.conv_dw = create_conv2d(
             mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
-            padding=pad_type, depthwise=True, **conv_kwargs)
-        self.bn2 = norm_layer(mid_chs)
-        self.act2 = act_layer(inplace=True)
+            groups=groups, padding=pad_type, **conv_kwargs)
+        self.bn2 = norm_act_layer(mid_chs, inplace=True)
         # Squeeze-and-excitation
         self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
-        self.bn3 = norm_layer(out_chs)
+        self.bn3 = norm_act_layer(out_chs, apply_act=False)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
     def feature_info(self, location):
         if location == 'expansion':  # after SE, input to PWL
-            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
         else:  # location == 'bottleneck', block output
-            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
-        return info
+            return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
     def forward(self, x):
         shortcut = x
-        # Point-wise expansion
         x = self.conv_pw(x)
         x = self.bn1(x)
-        x = self.act1(x)
-        # Depth-wise convolution
         x = self.conv_dw(x)
         x = self.bn2(x)
-        x = self.act2(x)
-        # Squeeze-and-excitation
        x = self.se(x)
-        # Point-wise linear projection
         x = self.conv_pwl(x)
         x = self.bn3(x)
-        if self.has_residual:
-            if self.drop_path_rate > 0.:
-                x = drop_path(x, self.drop_path_rate, self.training)
-            x += shortcut
+        if self.has_skip:
+            x = x + self.drop_path(shortcut)
         return x
@@ -210,7 +198,7 @@ class CondConvResidual(InvertedResidual):
     """ Inverted residual block w/ CondConv routing"""
     def __init__(
-            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
+            self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
             noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
             norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
@@ -218,8 +206,8 @@ class CondConvResidual(InvertedResidual):
         conv_kwargs = dict(num_experts=self.num_experts)
         super(CondConvResidual, self).__init__(
-            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
-            act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
+            in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
+            pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
             pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
             drop_path_rate=drop_path_rate)
@@ -227,32 +215,17 @@ class CondConvResidual(InvertedResidual):
     def forward(self, x):
         shortcut = x
-        # CondConv routing
-        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
+        pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)  # CondConv routing
         routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
-        # Point-wise expansion
         x = self.conv_pw(x, routing_weights)
         x = self.bn1(x)
-        x = self.act1(x)
-        # Depth-wise convolution
         x = self.conv_dw(x, routing_weights)
         x = self.bn2(x)
-        x = self.act2(x)
-        # Squeeze-and-excitation
         x = self.se(x)
-        # Point-wise linear projection
         x = self.conv_pwl(x, routing_weights)
         x = self.bn3(x)
-        if self.has_residual:
-            if self.drop_path_rate > 0.:
-                x = drop_path(x, self.drop_path_rate, self.training)
-            x += shortcut
+        if self.has_skip:
+            x = x + self.drop_path(shortcut)
         return x
@@ -269,55 +242,44 @@ class EdgeResidual(nn.Module):
     """
     def __init__(
-            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
+            self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='',
             force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
             norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
         super(EdgeResidual, self).__init__()
+        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         if force_in_chs > 0:
             mid_chs = make_divisible(force_in_chs * exp_ratio)
         else:
             mid_chs = make_divisible(in_chs * exp_ratio)
-        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
-        self.drop_path_rate = drop_path_rate
+        groups = num_groups(group_size, in_chs)
+        self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
         # Expansion convolution
         self.conv_exp = create_conv2d(
-            in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
-        self.bn1 = norm_layer(mid_chs)
-        self.act1 = act_layer(inplace=True)
+            in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
+        self.bn1 = norm_act_layer(mid_chs, inplace=True)
         # Squeeze-and-excitation
         self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
         # Point-wise linear projection
         self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
-        self.bn2 = norm_layer(out_chs)
+        self.bn2 = norm_act_layer(out_chs, apply_act=False)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
     def feature_info(self, location):
         if location == 'expansion':  # after SE, before PWL
-            info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
+            return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
         else:  # location == 'bottleneck', block output
-            info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
-        return info
+            return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
     def forward(self, x):
         shortcut = x
-        # Expansion convolution
         x = self.conv_exp(x)
         x = self.bn1(x)
-        x = self.act1(x)
-        # Squeeze-and-excitation
         x = self.se(x)
-        # Point-wise linear projection
         x = self.conv_pwl(x)
         x = self.bn2(x)
-        if self.has_residual:
-            if self.drop_path_rate > 0.:
-                x = drop_path(x, self.drop_path_rate, self.training)
-            x += shortcut
+        if self.has_skip:
+            x = x + self.drop_path(shortcut)
         return x

timm/models/efficientnet_builder.py

@@ -139,60 +139,52 @@ def _decode_block_str(block_str):
     exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
     pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
     force_in_chs = int(options['fc']) if 'fc' in options else 0  # FIXME hack to deal with in_chs issue in TPU def
     num_repeat = int(options['r'])
     # each type of block has different valid arguments, fill accordingly
+    block_args = dict(
+        block_type=block_type,
+        out_chs=int(options['c']),
+        stride=int(options['s']),
+        act_layer=act_layer,
+    )
     if block_type == 'ir':
-        block_args = dict(
-            block_type=block_type,
+        block_args.update(dict(
             dw_kernel_size=_parse_ksize(options['k']),
             exp_kernel_size=exp_kernel_size,
             pw_kernel_size=pw_kernel_size,
-            out_chs=int(options['c']),
             exp_ratio=float(options['e']),
             se_ratio=float(options['se']) if 'se' in options else 0.,
-            stride=int(options['s']),
-            act_layer=act_layer,
             noskip=skip is False,
-        )
+        ))
         if 'cc' in options:
             block_args['num_experts'] = int(options['cc'])
     elif block_type == 'ds' or block_type == 'dsa':
-        block_args = dict(
-            block_type=block_type,
+        block_args.update(dict(
             dw_kernel_size=_parse_ksize(options['k']),
             pw_kernel_size=pw_kernel_size,
-            out_chs=int(options['c']),
             se_ratio=float(options['se']) if 'se' in options else 0.,
-            stride=int(options['s']),
-            act_layer=act_layer,
             pw_act=block_type == 'dsa',
             noskip=block_type == 'dsa' or skip is False,
-        )
+        ))
     elif block_type == 'er':
-        block_args = dict(
-            block_type=block_type,
+        block_args.update(dict(
             exp_kernel_size=_parse_ksize(options['k']),
             pw_kernel_size=pw_kernel_size,
-            out_chs=int(options['c']),
             exp_ratio=float(options['e']),
             force_in_chs=force_in_chs,
             se_ratio=float(options['se']) if 'se' in options else 0.,
-            stride=int(options['s']),
-            act_layer=act_layer,
             noskip=skip is False,
-        )
+        ))
     elif block_type == 'cn':
-        block_args = dict(
-            block_type=block_type,
+        block_args.update(dict(
             kernel_size=int(options['k']),
-            out_chs=int(options['c']),
-            stride=int(options['s']),
-            act_layer=act_layer,
             skip=skip is True,
-        )
+        ))
     else:
         assert False, 'Unknown block type (%s)' % block_type
+    if 'gs' in options:
+        block_args['group_size'] = options['gs']
     return block_args, num_repeat
@@ -235,7 +227,27 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
     return sa_scaled
-def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
+def decode_arch_def(
+        arch_def,
+        depth_multiplier=1.0,
+        depth_trunc='ceil',
+        experts_multiplier=1,
+        fix_first_last=False,
+        group_size=None,
+):
+    """ Decode block architecture definition strings -> block kwargs
+
+    Args:
+        arch_def: architecture definition strings, list of list of strings
+        depth_multiplier: network depth multiplier
+        depth_trunc: networ depth truncation mode when applying multiplier
+        experts_multiplier: CondConv experts multiplier
+        fix_first_last: fix first and last block depths when multiplier is applied
+        group_size: group size override for all blocks that weren't explicitly set in arch string
+
+    Returns:
+        list of list of block kwargs
+    """
     arch_args = []
     if isinstance(depth_multiplier, tuple):
         assert len(depth_multiplier) == len(arch_def)
@@ -250,6 +262,8 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
             ba, rep = _decode_block_str(block_str)
             if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
                 ba['num_experts'] *= experts_multiplier
+            if group_size is not None:
+                ba.setdefault('group_size', group_size)
             stack_args.append(ba)
             repeats.append(rep)
         if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):
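
A small, illustrative sketch of the new group-size plumbing in the builder (block strings follow the existing EfficientNet arch defs; the 'gs' token and the group_size argument are parsed exactly as in the hunks above):

from timm.models.efficientnet_builder import decode_arch_def

arch_def = [
    ['ds_r1_k3_s1_e1_c16_se0.25'],      # no explicit group size in the string
    ['ir_r2_k3_s2_e6_c24_se0.25_gs8'],  # per-block 'gs' option
]

# group_size only fills in blocks that did not set 'gs' themselves (via setdefault above)
block_args = decode_arch_def(arch_def, depth_multiplier=1.0, group_size=16)
print(block_args[0][0]['group_size'])  # 16, from the override
print(block_args[1][0]['group_size'])  # '8', parsed from the 'gs8' token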

timm/models/layers/__init__.py

@@ -7,11 +7,11 @@ from .cond_conv2d import CondConv2d, get_condconv_initializer
 from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
     set_layer_config
 from .conv2d_same import Conv2dSame, conv2d_same
-from .conv_bn_act import ConvBnAct
+from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
-from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
+from .create_norm_act import get_norm_act_layer, create_norm_act_layer, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
 from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
@@ -32,7 +32,7 @@ from .patch_embed import PatchEmbed
 from .pool2d_same import AvgPool2dSame, create_pool2d
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
-from .separable_conv import SeparableConv2d, SeparableConvBnAct
+from .separable_conv import SeparableConv2d, SeparableConvNormAct
 from .space_to_depth import SpaceToDepthModule
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

timm/models/layers/cbam.py

@@ -11,7 +11,7 @@ import torch
 from torch import nn as nn
 import torch.nn.functional as F
-from .conv_bn_act import ConvBnAct
+from .conv_bn_act import ConvNormAct
 from .create_act import create_act_layer, get_act_layer
 from .helpers import make_divisible
@@ -56,7 +56,7 @@ class SpatialAttn(nn.Module):
     """
     def __init__(self, kernel_size=7, gate_layer='sigmoid'):
         super(SpatialAttn, self).__init__()
-        self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
+        self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
         self.gate = create_act_layer(gate_layer)
     def forward(self, x):
@@ -70,7 +70,7 @@ class LightSpatialAttn(nn.Module):
     """
     def __init__(self, kernel_size=7, gate_layer='sigmoid'):
         super(LightSpatialAttn, self).__init__()
-        self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
+        self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
         self.gate = create_act_layer(gate_layer)
     def forward(self, x):

@ -5,14 +5,46 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn from torch import nn as nn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act from .create_norm_act import get_norm_act_layer
class ConvBnAct(nn.Module): class ConvNormAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1, def __init__(
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
drop_block=None): bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
super(ConvBnAct, self).__init__() super(ConvNormAct, self).__init__()
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.conv.out_channels
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
ConvBnAct = ConvNormAct
class ConvNormActAa(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None use_aa = aa_layer is not None
self.conv = create_conv2d( self.conv = create_conv2d(
@ -20,9 +52,11 @@ class ConvBnAct(nn.Module):
padding=padding, dilation=dilation, groups=groups, bias=bias) padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions # NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = convert_norm_act(norm_layer, act_layer) norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
@property @property
def in_channels(self): def in_channels(self):
@ -35,6 +69,5 @@ class ConvBnAct(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
x = self.bn(x) x = self.bn(x)
if self.aa is not None: x = self.aa(x)
x = self.aa(x)
return x return x
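Not part of the diff, but a minimal usage sketch of the renamed layers may help; it assumes ConvNormAct, ConvNormActAa, BlurPool2d and DropBlock2d are all importable from timm.models.layers, as the updated __init__ above suggests:

import torch
import torch.nn as nn
from functools import partial
from timm.models.layers import ConvNormAct, ConvNormActAa, BlurPool2d, DropBlock2d

x = torch.randn(2, 32, 56, 56)
# conv + fused norm/act; the norm module keeps the `.bn` name for weight compatibility
block = ConvNormAct(32, 64, kernel_size=3, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU)
# anti-aliased variant; the aa layer is only active when stride == 2
block_aa = ConvNormActAa(
    32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d,
    drop_layer=partial(DropBlock2d, drop_prob=0.1))
print(block(x).shape)     # torch.Size([2, 64, 56, 56])
print(block_aa(x).shape)  # expected: torch.Size([2, 64, 28, 28])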

@ -16,7 +16,12 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
""" """
if isinstance(kernel_size, list): if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
assert 'groups' not in kwargs # MixedConv groups are defined by kernel list if 'groups' in kwargs:
groups = kwargs.pop('groups')
if groups == in_channels:
kwargs['depthwise'] = True
else:
assert groups == 1
# We're going to use only lists for defining the MixedConv2d kernel groups, # We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w. # ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs) m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)
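A small sketch of what the relaxed groups handling above permits (illustrative channel counts, not taken from the diff):

from timm.models.layers import create_conv2d

# groups == in_channels now maps to a depthwise MixedConv2d instead of tripping the old assert
dw_mixed = create_conv2d(64, 64, kernel_size=[3, 5, 7], groups=64)
# groups == 1 keeps the previous behaviour; grouping comes only from the kernel list
mixed = create_conv2d(64, 96, kernel_size=[3, 5], groups=1)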

@ -11,12 +11,15 @@ import functools
from .evo_norm import * from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
from .inplace_abn import InplaceAbn from .inplace_abn import InplaceAbn
_NORM_ACT_MAP = dict( _NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d, batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct, groupnorm=GroupNormAct,
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
evonormb0=EvoNorm2dB0, evonormb0=EvoNorm2dB0,
evonormb1=EvoNorm2dB1, evonormb1=EvoNorm2dB1,
evonormb2=EvoNorm2dB2, evonormb2=EvoNorm2dB2,
@ -33,28 +36,19 @@ _NORM_ACT_MAP = dict(
) )
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# has act_layer arg to define act type # has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn} _NORM_ACT_REQUIRES_ARG = {
BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
def get_norm_act_layer(layer_name): def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
layer_name = layer_name.replace('_', '').lower().split('-')[0] layer = get_norm_act_layer(layer_name, act_layer=act_layer)
layer = _NORM_ACT_MAP.get(layer_name, None)
assert layer is not None, "Invalid norm_act layer (%s)" % layer_name
return layer
def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_name.split('-') # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
layer_instance = layer(num_features, apply_act=apply_act, **kwargs) layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
if jit: if jit:
layer_instance = torch.jit.script(layer_instance) layer_instance = torch.jit.script(layer_instance)
return layer_instance return layer_instance
def convert_norm_act(norm_layer, act_layer): def get_norm_act_layer(norm_layer, act_layer=None):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial)) assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial)) assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_kwargs = {} norm_act_kwargs = {}
@ -65,7 +59,8 @@ def convert_norm_act(norm_layer, act_layer):
norm_layer = norm_layer.func norm_layer = norm_layer.func
if isinstance(norm_layer, str): if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer) layer_name = norm_layer.replace('_', '').lower().split('-')[0]
norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
elif norm_layer in _NORM_ACT_TYPES: elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer norm_act_layer = norm_layer
elif isinstance(norm_layer, types.FunctionType): elif isinstance(norm_layer, types.FunctionType):
@ -77,6 +72,10 @@ def convert_norm_act(norm_layer, act_layer):
norm_act_layer = BatchNormAct2d norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'): elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct norm_act_layer = GroupNormAct
elif type_name.startswith('layernorm2d'):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):
norm_act_layer = LayerNormAct
else: else:
assert False, f"No equivalent norm_act layer for {type_name}" assert False, f"No equivalent norm_act layer for {type_name}"
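A rough usage sketch of the reshuffled factories (function names as in this diff; return values described loosely, since the act binding details are not fully shown here):

import torch.nn as nn
from timm.models.layers import get_norm_act_layer, create_norm_act_layer

# resolve a fused norm+act class (or partial) from a string or a plain norm layer type
bn_act = get_norm_act_layer('batchnorm', act_layer='relu')    # BatchNormAct2d
ln_act = get_norm_act_layer(nn.LayerNorm, act_layer=nn.GELU)  # LayerNormAct
# build an instance directly by name
gn = create_norm_act_layer('groupnorm', 32, act_layer='silu')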

@ -20,7 +20,7 @@ import torch.nn.functional as F
def drop_block_2d( def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0, x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False): with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
@ -32,7 +32,7 @@ def drop_block_2d(
clipped_block_size = min(block_size, min(W, H)) clipped_block_size = min(block_size, min(W, H))
# seed_drop_rate, the gamma parameter # seed_drop_rate, the gamma parameter
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / ( gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1)) (W - block_size + 1) * (H - block_size + 1))
# Forces the block to be inside the feature map. # Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device)) w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
@ -104,14 +104,16 @@ def drop_block_fast_2d(
class DropBlock2d(nn.Module): class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
""" """
def __init__(self,
drop_prob=0.1, def __init__(
block_size=7, self,
gamma_scale=1.0, drop_prob=0.1,
with_noise=False, block_size=7,
inplace=False, gamma_scale=1.0,
batchwise=False, with_noise=False,
fast=True): inplace=False,
batchwise=False,
fast=True):
super(DropBlock2d, self).__init__() super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
self.gamma_scale = gamma_scale self.gamma_scale = gamma_scale
@ -155,6 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
""" """
def __init__(self, drop_prob=None, scale_by_keep=True): def __init__(self, drop_prob=None, scale_by_keep=True):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
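For context, a toy residual block showing where DropBlock2d and DropPath typically sit; this is a sketch, not code from this commit:

import torch
import torch.nn as nn
from timm.models.layers import DropBlock2d, DropPath

class ToyBlock(nn.Module):
    def __init__(self, chs, drop_block_rate=0.1, drop_path_rate=0.1):
        super().__init__()
        self.conv = nn.Conv2d(chs, chs, 3, padding=1)
        self.drop_block = DropBlock2d(drop_prob=drop_block_rate, block_size=3)
        self.drop_path = DropPath(drop_path_rate, scale_by_keep=True)

    def forward(self, x):
        # drop_block regularizes the conv features, drop_path drops the whole residual branch
        return x + self.drop_path(self.drop_block(self.conv(x)))

y = ToyBlock(16)(torch.randn(1, 16, 8, 8))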

@ -38,7 +38,7 @@ class InplaceAbn(nn.Module):
""" """
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True, def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
act_layer="leaky_relu", act_param=0.01, drop_block=None): act_layer="leaky_relu", act_param=0.01, drop_layer=None):
super(InplaceAbn, self).__init__() super(InplaceAbn, self).__init__()
self.num_features = num_features self.num_features = num_features
self.affine = affine self.affine = affine
@ -54,7 +54,7 @@ class InplaceAbn(nn.Module):
self.act_name = 'elu' self.act_name = 'elu'
elif act_layer == nn.LeakyReLU: elif act_layer == nn.LeakyReLU:
self.act_name = 'leaky_relu' self.act_name = 'leaky_relu'
elif act_layer == nn.Identity: elif act_layer is None or act_layer == nn.Identity:
self.act_name = 'identity' self.act_name = 'identity'
else: else:
assert False, f'Invalid act layer {act_layer.__name__} for IABN' assert False, f'Invalid act layer {act_layer.__name__} for IABN'

@ -8,7 +8,7 @@ import torch
from torch import nn from torch import nn
from torch.nn import functional as F from torch.nn import functional as F
from .conv_bn_act import ConvBnAct from .conv_bn_act import ConvNormAct
from .helpers import make_divisible from .helpers import make_divisible
from .trace_utils import _assert from .trace_utils import _assert
@ -74,10 +74,10 @@ class BilinearAttnTransform(nn.Module):
def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(BilinearAttnTransform, self).__init__() super(BilinearAttnTransform, self).__init__()
self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer) self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1)) self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size)) self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.block_size = block_size self.block_size = block_size
self.groups = groups self.groups = groups
self.in_channels = in_channels self.in_channels = in_channels
@ -132,9 +132,9 @@ class BatNonLocalAttn(nn.Module):
super().__init__() super().__init__()
if rd_channels is None: if rd_channels is None:
rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor) rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer) self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer) self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.dropout = nn.Dropout2d(p=drop_rate) self.dropout = nn.Dropout2d(p=drop_rate)
def forward(self, x): def forward(self, x):

@ -1,5 +1,7 @@
""" Normalization + Activation Layers """ Normalization + Activation Layers
""" """
from typing import Union, List
import torch import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
@ -14,12 +16,13 @@ class BatchNormAct2d(nn.BatchNorm2d):
compatible with weights trained with separate bn, act. This is why we inherit from BN compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member. instead of composing it as a .bn member.
""" """
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, def __init__(
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(BatchNormAct2d, self).__init__( super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats) num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
if isinstance(act_layer, str): self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act: if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {} act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args) self.act = act_layer(**act_args)
@ -29,8 +32,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
def _forward_jit(self, x): def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function """ A cut & paste of the contents of the PyTorch BatchNorm2d forward function
""" """
# exponential_average_factor is self.momentum set to # exponential_average_factor is set to self.momentum
# (when it is available) only so that if gets updated # (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX. # in ONNX graph when this node is exported to ONNX.
if self.momentum is None: if self.momentum is None:
exponential_average_factor = 0.0 exponential_average_factor = 0.0
@ -39,18 +42,38 @@ class BatchNormAct2d(nn.BatchNorm2d):
if self.training and self.track_running_stats: if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None # TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None: if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked += 1 self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked) exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average else: # use exponential moving average
exponential_average_factor = self.momentum exponential_average_factor = self.momentum
x = F.batch_norm( r"""
x, self.running_mean, self.running_var, self.weight, self.bias, Decide whether the mini-batch stats should be used for normalization rather than the buffers.
self.training or not self.track_running_stats, Mini-batch stats are used in training mode, and in eval mode when buffers are None.
exponential_average_factor, self.eps) """
return x if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
@torch.jit.ignore @torch.jit.ignore
def _forward_python(self, x): def _forward_python(self, x):
@ -62,6 +85,7 @@ class BatchNormAct2d(nn.BatchNorm2d):
x = self._forward_jit(x) x = self._forward_jit(x)
else: else:
x = self._forward_python(x) x = self._forward_python(x)
x = self.drop(x)
x = self.act(x) x = self.act(x)
return x return x
@ -91,13 +115,22 @@ def group_norm_tpu(x, w, b, groups: int = 32, eps: float = 1e-5, diff_sqm: bool
return x return x
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
return num_channels // group_size
return num_groups
class GroupNormAct(nn.GroupNorm): class GroupNormAct(nn.GroupNorm):
# NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args # NOTE num_channel and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True, def __init__(
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None): self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine) apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
if isinstance(act_layer, str): super(GroupNormAct, self).__init__(
act_layer = get_act_layer(act_layer) _num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act: if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {} act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args) self.act = act_layer(**act_args)
@ -109,5 +142,47 @@ class GroupNormAct(nn.GroupNorm):
x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps) x = group_norm_tpu(x, self.weight, self.bias, self.num_groups, self.eps)
else: else:
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps) x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct(nn.LayerNorm):
def __init__(
self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def forward(self, x):
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def forward(self, x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
x = self.drop(x)
x = self.act(x) x = self.act(x)
return x return x
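Quick shape check for the norm+act layers touched above (assumes they are exported from timm.models.layers; shapes are illustrative):

import torch
from timm.models.layers import BatchNormAct2d, GroupNormAct, LayerNormAct2d

x = torch.randn(2, 64, 14, 14)
bn = BatchNormAct2d(64, act_layer='silu')   # act_layer may now be given as a string
gn = GroupNormAct(64, group_size=16)        # new group_size arg: 64 channels / 16 -> 4 groups
ln = LayerNormAct2d(64)                     # LayerNorm over C of an NCHW tensor (permute inside)
for m in (bn, gn, ln):
    assert m(x).shape == x.shape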

@ -0,0 +1,143 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .weight_init import trunc_normal_
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
""" Compute relative logits along one dimension
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
Args:
q: (batch, heads, height, width, dim)
rel_k: (2 * width - 1, dim)
permute_mask: permute output dim according to this
"""
B, H, W, dim = q.shape
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, 2 * W - 1)
# pad to shift from relative to absolute indexing
x_pad = F.pad(x, [0, 1]).flatten(1)
x_pad = F.pad(x_pad, [0, W - 1])
# reshape and slice out the padded elements
x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
x = x_pad[:, :W, W - 1:]
# reshape and tile
x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
return x.permute(permute_mask)
class PosEmbedRel(nn.Module):
""" Relative Position Embedding
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
"""
def __init__(self, feat_size, dim_head, scale):
super().__init__()
self.height, self.width = to_2tuple(feat_size)
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
def forward(self, q):
B, num_heads, HW, _ = q.shape
# relative logits in width dimension.
q = q.reshape(B * num_heads, self.height, self.width, -1)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension.
q = q.transpose(1, 2)
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
return rel_logits
class BottleneckAttn(nn.Module):
""" Bottleneck Attention
Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
"""
def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
super().__init__()
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.num_heads = num_heads
self.dim_out = dim_out
self.dim_head = dim_out // num_heads
self.scale = self.dim_head ** -0.5
self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.pos_embed.height
assert W == self.pos_embed.width
x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out)
return attn_out
class PoolingAttention(nn.Module):
def __init__(self, in_features: int, attention_features: int, segments: int, max_pool_kernel: int):
super(PoolingAttention, self).__init__()
self.attn = nn.Linear(in_features, attention_features * 5)
self.segments = segments
self.max_pool_kernel = max_pool_kernel
def forward(self, inp: torch.Tensor): # Shape: [Batch, Sequence, Features]
batch, sequence, features = inp.size()
assert sequence % self.segments == 0
qry, key, val, seg, loc = self.attn(inp).chunk(5, 2) # 5x Shape: [Batch, Sequence, AttentionFeatures]
aggregated = qry.mean(1)  # Shape: [Batch, AttentionFeatures]
aggregated = torch.einsum("ba,bsa->bs", aggregated, key) # Shape: [Batch, Sequence]
aggregated = F.softmax(aggregated, 1)
aggregated = torch.einsum("bs,bsa,bza->bza", aggregated, val,
qry) # Shape: [Batch, Sequence, AttentionFeatures]
pooled_sequence = sequence // self.segments
segment_max_pooled = seg.view(batch, pooled_sequence, self.segments, -1)
segment_max_pooled = segment_max_pooled.max(2, keepdim=True).values  # Shape: [Batch, PooledSequence, 1, AttentionFeatures]
segment_max_pooled = segment_max_pooled * qry.view(batch, pooled_sequence, self.segments, -1) # Shape: [Batch, PooledSequence, PoolSize, AttentionFeatures]
segment_max_pooled = segment_max_pooled.view(batch, sequence, -1) # Shape: [Batch, Sequence, AttentionFeatures]
loc = loc.transpose(1, 2) # Shape: [Batch, AttentionFeatures, Sequence]
local_max_pooled = F.max_pool1d(loc, self.max_pool_kernel, 1, self.max_pool_kernel // 2)
local_max_pooled = local_max_pooled.transpose(1, 2) # Shape: [Batch, Sequence, AttentionFeatures]
return aggregated + segment_max_pooled + local_max_pooled
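A hedged usage sketch for the BottleneckAttn module defined above; feat_size must match the input spatial size, other values are illustrative:

import torch

attn = BottleneckAttn(dim=256, feat_size=(14, 14), stride=1, num_heads=4)
out = attn(torch.randn(2, 256, 14, 14))  # expected shape: (2, 256, 14, 14)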

@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch import torch
from torch import nn as nn from torch import nn as nn
from .conv_bn_act import ConvBnAct from .conv_bn_act import ConvNormActAa
from .helpers import make_divisible from .helpers import make_divisible
from .trace_utils import _assert from .trace_utils import _assert
@ -20,8 +20,7 @@ def _kernel_valid(k):
class SelectiveKernelAttn(nn.Module): class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32, def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
""" Selective Kernel Attention Module """ Selective Kernel Attention Module
Selective Kernel attention mechanism factored out into its own module. Selective Kernel attention mechanism factored out into its own module.
@ -51,7 +50,7 @@ class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1, def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True, rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None): act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None):
""" Selective Kernel Convolution Module """ Selective Kernel Convolution Module
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications. As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
@ -72,9 +71,10 @@ class SelectiveKernel(nn.Module):
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower, split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count can be viewed as grouping by path, output expands to module out_channels count
drop_block (nn.Module): drop block module
act_layer (nn.Module): activation layer to use act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use norm_layer (nn.Module): batchnorm/norm layer to use
aa_layer (nn.Module): anti-aliasing module
drop_layer (nn.Module): spatial drop module in convs (drop block, etc)
""" """
super(SelectiveKernel, self).__init__() super(SelectiveKernel, self).__init__()
out_channels = out_channels or in_channels out_channels = out_channels or in_channels
@ -97,15 +97,14 @@ class SelectiveKernel(nn.Module):
groups = min(out_channels, groups) groups = min(out_channels, groups)
conv_kwargs = dict( conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer) aa_layer=aa_layer, drop_layer=drop_layer)
self.paths = nn.ModuleList([ self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs) ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)]) for k, d in zip(kernel_size, dilation)])
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor) attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels) self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x): def forward(self, x):
if self.split_input: if self.split_input:
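Usage sketch for the reworked SelectiveKernel arguments (drop_layer instead of drop_block; channel counts are illustrative):

import torch
from functools import partial
from timm.models.layers import SelectiveKernel, DropBlock2d

sk = SelectiveKernel(64, 128, stride=2, drop_layer=partial(DropBlock2d, drop_prob=0.05))
out = sk(torch.randn(1, 64, 32, 32))  # expected shape: (1, 128, 16, 16)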

@ -8,16 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn from torch import nn as nn
from .create_conv2d import create_conv2d from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act from .create_norm_act import get_norm_act_layer
class SeparableConvBnAct(nn.Module): class SeparableConvNormAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation """ Separable Conv w/ trailing Norm and Activation
""" """
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False, def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
apply_act=True, drop_block=None): apply_act=True, drop_layer=None):
super(SeparableConvBnAct, self).__init__() super(SeparableConvNormAct, self).__init__()
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size, in_channels, int(in_channels * channel_multiplier), kernel_size,
@ -26,8 +26,9 @@ class SeparableConvBnAct(nn.Module):
self.conv_pw = create_conv2d( self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias) int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer = convert_norm_act(norm_layer, act_layer) norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block) norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
@property @property
def in_channels(self): def in_channels(self):
@ -40,11 +41,13 @@ class SeparableConvBnAct(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv_dw(x) x = self.conv_dw(x)
x = self.conv_pw(x) x = self.conv_pw(x)
if self.bn is not None: x = self.bn(x)
x = self.bn(x)
return x return x
SeparableConvBnAct = SeparableConvNormAct
class SeparableConv2d(nn.Module): class SeparableConv2d(nn.Module):
""" Separable Conv """ Separable Conv
""" """

@ -35,11 +35,10 @@ class SplitAttn(nn.Module):
""" """
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None, def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8, dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs): act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
super(SplitAttn, self).__init__() super(SplitAttn, self).__init__()
out_channels = out_channels or in_channels out_channels = out_channels or in_channels
self.radix = radix self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix mid_chs = out_channels * radix
if rd_channels is None: if rd_channels is None:
attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor) attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
@ -51,6 +50,7 @@ class SplitAttn(nn.Module):
in_channels, mid_chs, kernel_size, stride, padding, dilation, in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs) groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity() self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
self.act0 = act_layer(inplace=True) self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups) self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity() self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
@ -61,8 +61,7 @@ class SplitAttn(nn.Module):
def forward(self, x): def forward(self, x):
x = self.conv(x) x = self.conv(x)
x = self.bn0(x) x = self.bn0(x)
if self.drop_block is not None: x = self.drop(x)
x = self.drop_block(x)
x = self.act0(x) x = self.act0(x)
B, RC, H, W = x.shape B, RC, H, W = x.shape
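Usage sketch for SplitAttn with the renamed drop_layer argument; padding is given explicitly here to keep the example self-contained:

import torch
from timm.models.layers import SplitAttn

sa = SplitAttn(64, radix=2, padding=1)
out = sa(torch.randn(1, 64, 16, 16))  # expected shape: (1, 64, 16, 16)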

@ -20,7 +20,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer
from .registry import register_model from .registry import register_model
__all__ = ['MobileNetV3', 'MobileNetV3Features'] __all__ = ['MobileNetV3', 'MobileNetV3Features']
@ -95,6 +95,7 @@ class MobileNetV3(nn.Module):
super(MobileNetV3, self).__init__() super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d norm_layer = norm_layer or nn.BatchNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes self.num_classes = num_classes
self.num_features = num_features self.num_features = num_features
@ -103,8 +104,7 @@ class MobileNetV3(nn.Module):
# Stem # Stem
stem_size = round_chs_fn(stem_size) stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size) self.bn1 = norm_act_layer(stem_size, inplace=True)
self.act1 = act_layer(inplace=True)
# Middle stages (IR/ER/DS Blocks) # Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder( builder = EfficientNetBuilder(
@ -125,7 +125,7 @@ class MobileNetV3(nn.Module):
efficientnet_init_weights(self) efficientnet_init_weights(self)
def as_sequential(self): def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1] layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks) layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2]) layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
@ -144,7 +144,6 @@ class MobileNetV3(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.conv_stem(x) x = self.conv_stem(x)
x = self.bn1(x) x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x) x = self.blocks(x)
x = self.global_pool(x) x = self.global_pool(x)
x = self.conv_head(x) x = self.conv_head(x)

@ -9,7 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model from .registry import register_model
__all__ = ['NASNetALarge'] __all__ = ['NASNetALarge']
@ -420,7 +420,7 @@ class NASNetALarge(nn.Module):
channels = self.num_features // 24 channels = self.num_features // 24
# 24 is default value for the architecture # 24 is default value for the architecture
self.conv0 = ConvBnAct( self.conv0 = ConvNormAct(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2, in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

@ -13,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model from .registry import register_model
__all__ = ['PNASNet5Large'] __all__ = ['PNASNet5Large']
@ -243,7 +243,7 @@ class PNASNet5Large(nn.Module):
self.num_features = 4320 self.num_features = 4320
assert output_stride == 32 assert output_stride == 32
self.conv_0 = ConvBnAct( self.conv_0 = ConvNormAct(
in_chans, 96, kernel_size=3, stride=2, padding=0, in_chans, 96, kernel_size=3, stride=2, padding=0,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False) norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

@ -15,45 +15,76 @@ Hacked together by / Copyright 2020 Ross Wightman
""" """
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union, Callable
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct
from .registry import register_model from .registry import register_model
def _mcfg(**kwargs): @dataclass
cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32) class RegNetCfg:
cfg.update(**kwargs) depth: int = 21
return cfg w0: int = 80
wa: float = 42.63
wm: float = 2.66
group_size: int = 24
bottle_ratio: float = 1.
se_ratio: float = 0.
stem_width: int = 32
downsample: Optional[str] = 'conv1x1'
linear_out: bool = False
act_layer: Union[str, Callable] = 'relu'
norm_layer: Union[str, Callable] = 'batchnorm'
# Model FLOPS = three trailing digits * 10^8 # Model FLOPS = three trailing digits * 10^8
model_cfgs = dict( model_cfgs = dict(
regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13), # RegNet-X
regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22), regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16), regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16), regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18), regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25), regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23), regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17), regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23), regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19), regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22), regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23), regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25), regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25), # RegNet-Y
regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25), regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25), regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25), regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25), regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25), regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25), regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25), regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25), regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25), regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
# Experimental
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
),
) )
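The new experimental configs register like any other timm model; a hedged example follows (no pretrained weights are published for them in this commit, so pretrained=False):

import torch
import timm

model = timm.create_model('regnetz_005', pretrained=False, num_classes=10)
out = model(torch.randn(1, 3, 224, 224))  # expected shape: (1, 10)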
@ -80,6 +111,7 @@ default_cfgs = dict(
regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'), regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'), regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'), regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'), regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'), regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'), regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
@ -96,6 +128,11 @@ default_cfgs = dict(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
crop_pct=1.0, test_input_size=(3, 288, 288)), crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
) )
@ -125,6 +162,40 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
return widths, num_stages, max_stage, widths_cont return widths, num_stages, max_stage, widths_cont
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvNormAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn.Identity()
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)])
def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
if not downsample_type:
return None # no shortcut, no downsample
elif downsample_type == 'avg':
return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
else:
return downsample_conv(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
else:
return nn.Identity() # identity shortcut (no downsample)
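The three shortcut modes behave roughly as follows (a sketch against create_shortcut as defined just above, intended to run inside this module):

sc_skip = create_shortcut(None, 64, 128, 1, stride=2)       # None -> block adds no shortcut at all
sc_avg = create_shortcut('avg', 64, 128, 1, stride=2)       # AvgPool followed by a 1x1 ConvNormAct
sc_conv = create_shortcut('conv1x1', 64, 128, 1, stride=2)  # strided 1x1 ConvNormAct projection
sc_id = create_shortcut('conv1x1', 64, 64, 1, stride=1)     # shapes already match -> nn.Identity()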
class Bottleneck(nn.Module): class Bottleneck(nn.Module):
""" RegNet Bottleneck """ RegNet Bottleneck
@ -132,97 +203,70 @@ class Bottleneck(nn.Module):
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
""" """
def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25, def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path=None): drop_block=None, drop_path_rate=0.):
super(Bottleneck, self).__init__() super(Bottleneck, self).__init__()
bottleneck_chs = int(round(out_chs * bottleneck_ratio)) act_layer = get_act_layer(act_layer)
groups = bottleneck_chs // group_width bottleneck_chs = int(round(out_chs * bottle_ratio))
groups = bottleneck_chs // group_size
cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs) cargs = dict(act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvBnAct( self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation, self.conv2 = ConvNormAct(
groups=groups, **cargs) bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0],
groups=groups, drop_layer=drop_block, **cargs)
if se_ratio: if se_ratio:
se_channels = int(round(in_chs * se_ratio)) se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels) self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
else: else:
self.se = None self.se = nn.Identity()
cargs['act_layer'] = None self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs)
self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs) self.act3 = nn.Identity() if linear_out else act_layer()
self.act3 = act_layer(inplace=True) self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, norm_layer=norm_layer)
self.downsample = downsample self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.drop_path = drop_path
def zero_init_last(self):
def zero_init_last_bn(self):
nn.init.zeros_(self.conv3.bn.weight) nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x): def forward(self, x):
shortcut = x shortcut = x
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
if self.se is not None: x = self.se(x)
x = self.se(x)
x = self.conv3(x) x = self.conv3(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None: if self.downsample is not None:
shortcut = self.downsample(shortcut) # NOTE stuck with downsample as the attr name due to weight compatibility
x += shortcut # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = self.drop_path(x) + self.downsample(shortcut)
x = self.act3(x) x = self.act3(x)
return x return x
def downsample_conv(
in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvBnAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None)
def downsample_avg(
in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn.Identity()
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)])
class RegStage(nn.Module): class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape).""" """Stage (sequence of blocks w/ the same output shape)."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width, def __init__(
block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None): self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck,
se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_path_rates=None, drop_block=None):
super(RegStage, self).__init__() super(RegStage, self).__init__()
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args block_kwargs = dict(
bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample,
linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block)
first_dilation = 1 if dilation in (1, 2) else 2 first_dilation = 1 if dilation in (1, 2) else 2
for i in range(depth): for i in range(depth):
block_stride = stride if i == 0 else 1 block_stride = stride if i == 0 else 1
block_in_chs = in_chs if i == 0 else out_chs block_in_chs = in_chs if i == 0 else out_chs
block_dilation = first_dilation if i == 0 else dilation block_dilation = (first_dilation, dilation)
if drop_path_rates is not None and drop_path_rates[i] > 0.: dpr = drop_path_rates[i] if drop_path_rates is not None else 0.
drop_path = DropPath(drop_path_rates[i])
else:
drop_path = None
if (block_in_chs != out_chs) or (block_stride != 1):
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
else:
proj_block = None
name = "b{}".format(i + 1) name = "b{}".format(i + 1)
self.add_module( self.add_module(
name, block_fn( name, block_fn(
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio, block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs) drop_path_rate=dpr, **block_kwargs)
) )
first_dilation = dilation
def forward(self, x): def forward(self, x):
for block in self.children(): for block in self.children():
@ -231,33 +275,34 @@ class RegStage(nn.Module):
class RegNet(nn.Module): class RegNet(nn.Module):
"""RegNet model. """RegNet-X, Y, and Z Models
Paper: https://arxiv.org/abs/2003.13678 Paper: https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
""" """
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0., def __init__(
drop_path_rate=0., zero_init_last_bn=True): self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
drop_rate=0., drop_path_rate=0., zero_init_last=True):
super().__init__() super().__init__()
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
self.num_classes = num_classes self.num_classes = num_classes
self.drop_rate = drop_rate self.drop_rate = drop_rate
assert output_stride in (8, 16, 32) assert output_stride in (8, 16, 32)
# Construct the stem # Construct the stem
stem_width = cfg['stem_width'] stem_width = cfg.stem_width
self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2) self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
# Construct the stages # Construct the stages
prev_width = stem_width prev_width = stem_width
curr_stride = 2 curr_stride = 2
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
se_ratio = cfg['se_ratio']
for i, stage_args in enumerate(stage_params): for i, stage_args in enumerate(stage_params):
stage_name = "s{}".format(i + 1) stage_name = "s{}".format(i + 1)
self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio)) self.add_module(stage_name, RegStage(
in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args))
prev_width = stage_args['out_chs'] prev_width = stage_args['out_chs']
curr_stride *= stage_args['stride'] curr_stride *= stage_args['stride']
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
@ -267,31 +312,18 @@ class RegNet(nn.Module):
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
for m in self.modules(): named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
# Generate RegNet ws per block # Generate RegNet ws per block
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth'] widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth)
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
# Convert to per stage format # Convert to per stage format
stage_widths, stage_depths = np.unique(widths, return_counts=True) stage_widths, stage_depths = np.unique(widths, return_counts=True)
# Use the same group width, bottleneck mult and stride for each stage # Use the same group width, bottleneck mult and stride for each stage
stage_groups = [cfg['group_w'] for _ in range(num_stages)] stage_groups = [cfg.group_size for _ in range(num_stages)]
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)] stage_bottle_ratios = [cfg.bottle_ratio for _ in range(num_stages)]
stage_strides = [] stage_strides = []
stage_dilations = [] stage_dilations = []
net_stride = 2 net_stride = 2
@ -305,11 +337,11 @@ class RegNet(nn.Module):
net_stride *= stride net_stride *= stride
stage_strides.append(stride) stage_strides.append(stride)
stage_dilations.append(dilation) stage_dilations.append(dilation)
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1])) stage_dpr = np.split(np.linspace(0, drop_path_rate, cfg.depth), np.cumsum(stage_depths[:-1]))
# Adjust the compatibility of ws and gws # Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates'] param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
stage_params = [ stage_params = [
dict(zip(param_names, params)) for params in dict(zip(param_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
@ -333,6 +365,19 @@ class RegNet(nn.Module):
return x return x
def _init_weights(module, name='', zero_init_last=False):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias)
elif hasattr(module, 'zero_init_last'):
module.zero_init_last()
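A minimal sketch of how the functional init above gets applied, mirroring the named_apply(partial(_init_weights, ...), self) call in RegNet.__init__; this apply_init helper is illustrative only, not a timm API:

import torch.nn as nn

def apply_init(model: nn.Module, zero_init_last: bool = True):
    # visit every submodule and hand it to the _init_weights function defined above
    for name, module in model.named_modules():
        _init_weights(module, name=name, zero_init_last=zero_init_last)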
def _filter_fn(state_dict): def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv""" """ convert patch embedding weight from manual patchify + linear proj to conv"""
if 'model' in state_dict: if 'model' in state_dict:
@ -492,3 +537,27 @@ def regnety_160(pretrained=False, **kwargs):
def regnety_320(pretrained=False, **kwargs): def regnety_320(pretrained=False, **kwargs):
"""RegNetY-32GF""" """RegNetY-32GF"""
return _create_regnet('regnety_320', pretrained, **kwargs) return _create_regnet('regnety_320', pretrained, **kwargs)
@register_model
def regnety_040s_gn(pretrained=False, **kwargs):
"""RegNetY-4.0GF w/ GroupNorm """
return _create_regnet('regnety_040s_gn', pretrained, **kwargs)
@register_model
def regnetz_005(pretrained=False, **kwargs):
"""RegNetZ-500MF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear whether it is equivalent to the paper's model, as the config is not detailed in the paper.
"""
return _create_regnet('regnetz_005', pretrained, **kwargs)
@register_model
def regnetz_040(pretrained=False, **kwargs):
"""RegNetZ-4.0GF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear whether it is equivalent to the paper's model, as the config is not detailed in the paper.
"""
return _create_regnet('regnetz_040', pretrained, **kwargs)
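For reference, a minimal usage sketch of the variants registered above through the standard timm factory (model names as registered here; pretrained weight availability is not implied):

import timm

gn_model = timm.create_model('regnety_040s_gn', pretrained=False)
z_model = timm.create_model('regnetz_005', pretrained=False, num_classes=10)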

@ -75,7 +75,6 @@ class ResNestBottleneck(nn.Module):
else: else:
avd_stride = 0 avd_stride = 0
self.radix = radix self.radix = radix
self.drop_block = drop_block
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False) self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
self.bn1 = norm_layer(group_width) self.bn1 = norm_layer(group_width)
@ -85,14 +84,16 @@ class ResNestBottleneck(nn.Module):
if self.radix >= 1: if self.radix >= 1:
self.conv2 = SplitAttn( self.conv2 = SplitAttn(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block) dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block)
self.bn2 = nn.Identity() self.bn2 = nn.Identity()
self.drop_block = nn.Identity()
self.act2 = nn.Identity() self.act2 = nn.Identity()
else: else:
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation, group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, bias=False) dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width) self.bn2 = norm_layer(group_width)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
@ -109,8 +110,6 @@ class ResNestBottleneck(nn.Module):
out = self.conv1(x) out = self.conv1(x)
out = self.bn1(out) out = self.bn1(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act1(out) out = self.act1(out)
if self.avd_first is not None: if self.avd_first is not None:
@ -118,8 +117,7 @@ class ResNestBottleneck(nn.Module):
out = self.conv2(out) out = self.conv2(out)
out = self.bn2(out) out = self.bn2(out)
if self.drop_block is not None: out = self.drop_block(out)
out = self.drop_block(out)
out = self.act2(out) out = self.act2(out)
if self.avd_last is not None: if self.avd_last is not None:
@ -127,8 +125,6 @@ class ResNestBottleneck(nn.Module):
out = self.conv3(out) out = self.conv3(out)
out = self.bn3(out) out = self.bn3(out)
if self.drop_block is not None:
out = self.drop_block(out)
if self.downsample is not None: if self.downsample is not None:
shortcut = self.downsample(x) shortcut = self.downsample(x)

@ -307,8 +307,9 @@ class BasicBlock(nn.Module):
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation, inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False) dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes) self.bn1 = norm_layer(first_planes)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act1 = act_layer(inplace=True) self.act1 = act_layer(inplace=True)
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else nn.Identity()
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False) first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -320,7 +321,6 @@ class BasicBlock(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path self.drop_path = drop_path
def zero_init_last_bn(self): def zero_init_last_bn(self):
@ -331,16 +331,12 @@ class BasicBlock(nn.Module):
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
if self.drop_block is not None: x = self.drop_block(x)
x = self.drop_block(x)
x = self.act1(x) x = self.act1(x)
if self.aa is not None: x = self.aa(x)
x = self.aa(x)
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None: if self.se is not None:
x = self.se(x) x = self.se(x)
@ -378,8 +374,9 @@ class Bottleneck(nn.Module):
first_planes, width, kernel_size=3, stride=1 if use_aa else stride, first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False) padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width) self.bn2 = norm_layer(width)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act2 = act_layer(inplace=True) self.act2 = act_layer(inplace=True)
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None self.aa = aa_layer(channels=width, stride=stride) if use_aa else nn.Identity()
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False) self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes) self.bn3 = norm_layer(outplanes)
@ -390,7 +387,6 @@ class Bottleneck(nn.Module):
self.downsample = downsample self.downsample = downsample
self.stride = stride self.stride = stride
self.dilation = dilation self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path self.drop_path = drop_path
def zero_init_last_bn(self): def zero_init_last_bn(self):
@ -401,22 +397,16 @@ class Bottleneck(nn.Module):
x = self.conv1(x) x = self.conv1(x)
x = self.bn1(x) x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x) x = self.act1(x)
x = self.conv2(x) x = self.conv2(x)
x = self.bn2(x) x = self.bn2(x)
if self.drop_block is not None: x = self.drop_block(x)
x = self.drop_block(x)
x = self.act2(x) x = self.act2(x)
if self.aa is not None: x = self.aa(x)
x = self.aa(x)
x = self.conv3(x) x = self.conv3(x)
x = self.bn3(x) x = self.bn3(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None: if self.se is not None:
x = self.se(x) x = self.se(x)
@ -463,11 +453,11 @@ def downsample_avg(
]) ])
def drop_blocks(drop_block_rate=0.): def drop_blocks(drop_prob=0.):
return [ return [
None, None, None, None,
DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None, partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None,
DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None] partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None]
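drop_blocks now returns partial constructors rather than module instances, so each block builds its own DropBlock2d (see the drop_block() if drop_block is not None else nn.Identity() pattern earlier in this diff). A small sketch of that consumption, with drop_prob chosen arbitrarily:

from functools import partial
import torch.nn as nn
from timm.models.layers import DropBlock2d  # external import path assumed for this timm version

db_fn = partial(DropBlock2d, drop_prob=0.1, block_size=5, gamma_scale=0.25)
drop_block = db_fn() if db_fn is not None else nn.Identity()  # a fresh DropBlock2d per block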
def make_blocks( def make_blocks(

@ -17,7 +17,7 @@ from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
from .registry import register_model from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights from .efficientnet_builder import efficientnet_init_weights
@ -63,19 +63,19 @@ class LinearBottleneck(nn.Module):
if exp_ratio != 1.: if exp_ratio != 1.:
dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div) dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer) self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer)
else: else:
dw_chs = in_chs dw_chs = in_chs
self.conv_exp = None self.conv_exp = None
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False) self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
if se_ratio > 0: if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div)) self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else: else:
self.se = None self.se = None
self.act_dw = create_act_layer(dw_act_layer) self.act_dw = create_act_layer(dw_act_layer)
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False) self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False)
self.drop_path = drop_path self.drop_path = drop_path
def feat_channels(self, exp=False): def feat_channels(self, exp=False):
@ -138,7 +138,7 @@ def _build_blocks(
feat_chs += [features[-1].feat_channels()] feat_chs += [features[-1].feat_channels()]
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div) pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')] feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer)) features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer))
return features, feature_info return features, feature_info
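Most edits in this file are the mechanical ConvBnAct -> ConvNormAct rename. A hedged sketch of the renamed layer as it is used throughout this diff (assuming it is exported from timm.models.layers like its predecessor): a 3x3 stem conv with norm + activation, and a 1x1 projection with apply_act=False to skip the activation.

import torch
from timm.models.layers import ConvNormAct

stem = ConvNormAct(3, 32, 3, stride=2)            # conv3x3 s2 + BatchNorm + ReLU
proj = ConvNormAct(32, 64, 1, apply_act=False)    # conv1x1 + BatchNorm, no activation
out = proj(stem(torch.randn(1, 3, 224, 224)))     # -> torch.Size([1, 64, 112, 112])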
@ -153,7 +153,7 @@ class ReXNetV1(nn.Module):
assert output_stride == 32 # FIXME support dilation assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32 stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div) stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer) self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div) block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
features, self.feature_info = _build_blocks( features, self.feature_info = _build_blocks(

@ -14,7 +14,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import SelectiveKernel, ConvBnAct, create_attn from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn
from .registry import register_model from .registry import register_model
from .resnet import ResNet from .resnet import ResNet
@ -52,7 +52,7 @@ class SelectiveKernelBasic(nn.Module):
super(SelectiveKernelBasic, self).__init__() super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {} sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
assert cardinality == 1, 'BasicBlock only supports cardinality of 1' assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock does not support changing base width' assert base_width == 64, 'BasicBlock does not support changing base width'
first_planes = planes // reduce_first first_planes = planes // reduce_first
@ -60,16 +60,13 @@ class SelectiveKernelBasic(nn.Module):
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
self.conv1 = SelectiveKernel( self.conv1 = SelectiveKernel(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs) inplanes, first_planes, stride=stride, dilation=first_dilation,
conv_kwargs['act_layer'] = None aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
self.conv2 = ConvBnAct( self.conv2 = ConvNormAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs) first_planes, outplanes, kernel_size=3, dilation=dilation, apply_act=False, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes) self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True) self.act = act_layer(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path self.drop_path = drop_path
def zero_init_last_bn(self): def zero_init_last_bn(self):
@ -100,24 +97,20 @@ class SelectiveKernelBottleneck(nn.Module):
super(SelectiveKernelBottleneck, self).__init__() super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {} sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer) conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality) width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first first_planes = width // reduce_first
outplanes = planes * self.expansion outplanes = planes * self.expansion
first_dilation = first_dilation or dilation first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs) self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv2 = SelectiveKernel( self.conv2 = SelectiveKernel(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality, first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs) aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs)
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes) self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True) self.act = act_layer(inplace=True)
self.downsample = downsample self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path self.drop_path = drop_path
def zero_init_last_bn(self): def zero_init_last_bn(self):

@ -20,8 +20,8 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model from .registry import register_model
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\ from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\
create_attn, create_norm_act, get_norm_act_layer create_attn, create_norm_act_layer, get_norm_act_layer
# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 & # model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
@ -189,23 +189,23 @@ class OsaBlock(nn.Module):
next_in_chs = in_chs next_in_chs = in_chs
if self.depthwise and next_in_chs != mid_chs: if self.depthwise and next_in_chs != mid_chs:
assert not residual assert not residual
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs) self.conv_reduction = ConvNormAct(next_in_chs, mid_chs, 1, **conv_kwargs)
else: else:
self.conv_reduction = None self.conv_reduction = None
mid_convs = [] mid_convs = []
for i in range(layer_per_block): for i in range(layer_per_block):
if self.depthwise: if self.depthwise:
conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs) conv = SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs)
else: else:
conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs) conv = ConvNormAct(next_in_chs, mid_chs, 3, **conv_kwargs)
next_in_chs = mid_chs next_in_chs = mid_chs
mid_convs.append(conv) mid_convs.append(conv)
self.conv_mid = SequentialAppendList(*mid_convs) self.conv_mid = SequentialAppendList(*mid_convs)
# feature aggregation # feature aggregation
next_in_chs = in_chs + layer_per_block * mid_chs next_in_chs = in_chs + layer_per_block * mid_chs
self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs) self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs)
if attn: if attn:
self.attn = create_attn(attn, out_chs) self.attn = create_attn(attn, out_chs)
@ -283,9 +283,9 @@ class VovNet(nn.Module):
# Stem module # Stem module
last_stem_stride = stem_stride // 2 last_stem_stride = stem_stride // 2
conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct
self.stem = nn.Sequential(*[ self.stem = nn.Sequential(*[
ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs), ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs), conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs), conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
]) ])
@ -395,12 +395,12 @@ def eca_vovnet39b(pretrained=False, **kwargs):
@register_model @register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs): def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **nkwargs): def norm_act_fn(num_features, **nkwargs):
return create_norm_act('evonorms0', num_features, jit=False, **nkwargs) return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs)
return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs) return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
@register_model @register_model
def ese_vovnet99b_iabn(pretrained=False, **kwargs): def ese_vovnet99b_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn') norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
return _create_vovnet( return _create_vovnet(
'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs) 'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)
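The factory call above now resolves the activation together with the norm ('iabn' paired with leaky_relu). A hedged sketch of the same pattern with a plain BatchNorm-backed norm+act layer, since 'iabn' requires the optional inplace_abn package (assumes 'batchnorm' resolves to BatchNormAct2d in the layer factory):

import torch.nn as nn
from timm.models.layers import get_norm_act_layer

norm_act_layer = get_norm_act_layer('batchnorm', act_layer=nn.LeakyReLU)
layer = norm_act_layer(64)  # fused norm + activation for 64 channels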

@ -12,7 +12,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, create_conv2d from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
from .layers.helpers import to_3tuple from .layers.helpers import to_3tuple
from .registry import register_model from .registry import register_model
@ -37,12 +37,14 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
xception71=_cfg( xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'), url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(url=''),
) )
class SeparableConv2d(nn.Module): class SeparableConv2d(nn.Module):
def __init__( def __init__(
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='', self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d): act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SeparableConv2d, self).__init__() super(SeparableConv2d, self).__init__()
self.kernel_size = kernel_size self.kernel_size = kernel_size
@ -50,31 +52,48 @@ class SeparableConv2d(nn.Module):
# depthwise convolution # depthwise convolution
self.conv_dw = create_conv2d( self.conv_dw = create_conv2d(
inplanes, inplanes, kernel_size, stride=stride, in_chs, in_chs, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True) padding=padding, dilation=dilation, depthwise=True)
self.bn_dw = norm_layer(inplanes) self.bn_dw = norm_layer(in_chs)
if act_layer is not None: self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
self.act_dw = act_layer(inplace=True)
else:
self.act_dw = None
# pointwise convolution # pointwise convolution
self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1) self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1)
self.bn_pw = norm_layer(planes) self.bn_pw = norm_layer(out_chs)
if act_layer is not None: self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
self.act_pw = act_layer(inplace=True)
else:
self.act_pw = None
def forward(self, x): def forward(self, x):
x = self.conv_dw(x) x = self.conv_dw(x)
x = self.bn_dw(x) x = self.bn_dw(x)
if self.act_dw is not None: x = self.act_dw(x)
x = self.act_dw(x)
x = self.conv_pw(x) x = self.conv_pw(x)
x = self.bn_pw(x) x = self.bn_pw(x)
if self.act_pw is not None: x = self.act_pw(x)
x = self.act_pw(x) return x
class PreSeparableConv2d(nn.Module):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, first_act=True):
super(PreSeparableConv2d, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
self.kernel_size = kernel_size
self.dilation = dilation
self.norm = norm_act_layer(in_chs, inplace=True) if first_act else nn.Identity()
# depthwise convolution
self.conv_dw = create_conv2d(
in_chs, in_chs, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True)
# pointwise convolution
self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1)
def forward(self, x):
x = self.norm(x)
x = self.conv_dw(x)
x = self.conv_pw(x)
return x return x
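PreSeparableConv2d above moves a single norm + activation in front of the depthwise conv and drops the per-conv norm/act pairs used by SeparableConv2d. A quick shape check using the defaults defined in this file:

import torch

m = PreSeparableConv2d(64, 128, kernel_size=3, stride=2)
y = m(torch.randn(1, 64, 56, 56))
assert y.shape == (1, 128, 28, 28)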
@ -88,8 +107,8 @@ class XceptionModule(nn.Module):
self.out_channels = out_chs[-1] self.out_channels = out_chs[-1]
self.no_skip = no_skip self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1): if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = ConvBnAct( self.shortcut = ConvNormAct(
in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None) in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, apply_act=False)
else: else:
self.shortcut = None self.shortcut = None
@ -97,7 +116,7 @@ class XceptionModule(nn.Module):
self.stack = nn.Sequential() self.stack = nn.Sequential()
for i in range(3): for i in range(3):
if start_with_relu: if start_with_relu:
self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0)) self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0))
self.stack.add_module(f'conv{i + 1}', SeparableConv2d( self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type, in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=separable_act_layer, norm_layer=norm_layer)) act_layer=separable_act_layer, norm_layer=norm_layer))
@ -113,11 +132,42 @@ class XceptionModule(nn.Module):
return x return x
class PreXceptionModule(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
no_skip=False, act_layer=nn.ReLU, norm_layer=None):
super(PreXceptionModule, self).__init__()
out_chs = to_3tuple(out_chs)
self.in_channels = in_chs
self.out_channels = out_chs[-1]
self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride)
else:
self.shortcut = nn.Identity()
self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True)
self.stack = nn.Sequential()
for i in range(3):
self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=act_layer, norm_layer=norm_layer, first_act=i > 0))
in_chs = out_chs[i]
def forward(self, x):
x = self.norm(x)
skip = x
x = self.stack(x)
if not self.no_skip:
x = x + self.shortcut(skip)
return x
class XceptionAligned(nn.Module): class XceptionAligned(nn.Module):
"""Modified Aligned Xception """Modified Aligned Xception
""" """
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'): act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
super(XceptionAligned, self).__init__() super(XceptionAligned, self).__init__()
self.num_classes = num_classes self.num_classes = num_classes
@ -126,31 +176,33 @@ class XceptionAligned(nn.Module):
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer) layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
self.stem = nn.Sequential(*[ self.stem = nn.Sequential(*[
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args), ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args) create_conv2d(32, 64, kernel_size=3, stride=1) if preact else
ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args)
]) ])
curr_dilation = 1 curr_dilation = 1
curr_stride = 2 curr_stride = 2
self.feature_info = [] self.feature_info = []
self.blocks = nn.Sequential() self.blocks = nn.Sequential()
module_fn = PreXceptionModule if preact else XceptionModule
for i, b in enumerate(block_cfg): for i, b in enumerate(block_cfg):
b['dilation'] = curr_dilation b['dilation'] = curr_dilation
if b['stride'] > 1: if b['stride'] > 1:
self.feature_info += [dict( name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')] self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)]
next_stride = curr_stride * b['stride'] next_stride = curr_stride * b['stride']
if next_stride > output_stride: if next_stride > output_stride:
curr_dilation *= b['stride'] curr_dilation *= b['stride']
b['stride'] = 1 b['stride'] = 1
else: else:
curr_stride = next_stride curr_stride = next_stride
self.blocks.add_module(str(i), XceptionModule(**b, **layer_args)) self.blocks.add_module(str(i), module_fn(**b, **layer_args))
self.num_features = self.blocks[-1].out_channels self.num_features = self.blocks[-1].out_channels
self.feature_info += [dict( self.feature_info += [dict(
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))] num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead( self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -163,6 +215,7 @@ class XceptionAligned(nn.Module):
def forward_features(self, x): def forward_features(self, x):
x = self.stem(x) x = self.stem(x)
x = self.blocks(x) x = self.blocks(x)
x = self.act(x)
return x return x
def forward(self, x): def forward(self, x):
@ -236,3 +289,22 @@ def xception71(pretrained=False, **kwargs):
] ]
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs) model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception71', pretrained=pretrained, **model_args) return _xception('xception71', pretrained=pretrained, **model_args)
@register_model
def xception41p(pretrained=False, **kwargs):
""" Modified Aligned Xception-41 w/ Pre-Act
"""
block_cfg = [
# entry flow
dict(in_chs=64, out_chs=128, stride=2),
dict(in_chs=128, out_chs=256, stride=2),
dict(in_chs=256, out_chs=728, stride=2),
# middle flow
*([dict(in_chs=728, out_chs=728, stride=1)] * 8),
# exit flow
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1),
]
model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d, **kwargs)
return _xception('xception41p', pretrained=pretrained, **model_args)
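A minimal usage sketch for the pre-activation model registered above (its default_cfg url is empty, so no pretrained weights are implied):

import timm

model = timm.create_model('xception41p', pretrained=False)
model.eval()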
