Significant norm update

* ConvBnAct layer renamed -> ConvNormAct, with ConvNormActAa added for the anti-aliased variant (usage sketch below)
* Significant update to EfficientNet and MobileNetV3 archs to support NormAct layers and grouped conv (as an alternative to depthwise)
* Update RegNet to add Z variant
* Add Pre variant of XceptionAligned that works with NormAct layers
* EvoNorm matches bits_and_tpu branch for merge
pull/1014/head
Ross Wightman 3 years ago
parent d04f2f1377
commit ab49d275de
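For orientation, a minimal usage sketch of the renamed layers (not part of the commit, assumes a timm build containing this change). ConvBnAct is kept as a backwards-compatible alias, and ConvNormActAa only inserts the anti-aliasing layer when stride == 2:

import torch
import torch.nn as nn
from timm.models.layers import ConvNormAct, ConvNormActAa, BlurPool2d

x = torch.randn(2, 32, 56, 56)

# conv + norm + act; the norm+act module is still registered as `.bn` for weight compatibility
block = ConvNormAct(32, 64, kernel_size=3, norm_layer=nn.BatchNorm2d, act_layer=nn.SiLU)
print(block(x).shape)     # torch.Size([2, 64, 56, 56])

# anti-aliased variant: the blur-pool layer is only created for stride == 2
block_aa = ConvNormActAa(32, 64, kernel_size=3, stride=2, aa_layer=BlurPool2d)
print(block_aa(x).shape)  # torch.Size([2, 64, 28, 28]), overall stride 2 via blur-pool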

@ -34,8 +34,8 @@ import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
from .layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\
EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d
from .registry import register_model
@ -921,7 +921,7 @@ def num_groups(group_size, channels):
@dataclass
class LayerFn:
conv_norm_act: Callable = ConvBnAct
conv_norm_act: Callable = ConvNormAct
norm_act: Callable = BatchNormAct2d
act: Callable = nn.ReLU
attn: Optional[Callable] = None
@ -978,7 +978,7 @@ class BasicBlock(nn.Module):
self.conv1_kxk = layers.conv_norm_act(in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0])
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block, apply_act=False)
mid_chs, out_chs, kernel_size, dilation=dilation[1], groups=groups, drop_layer=drop_block, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
@ -1019,11 +1019,9 @@ class BottleneckBlock(nn.Module):
self.conv1_1x1 = layers.conv_norm_act(in_chs, mid_chs, 1)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block)
if extra_conv:
self.conv2b_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups, drop_block=drop_block)
self.conv2b_kxk = layers.conv_norm_act(mid_chs, mid_chs, kernel_size, dilation=dilation[1], groups=groups)
else:
self.conv2b_kxk = nn.Identity()
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
@ -1080,7 +1078,7 @@ class DarkBlock(nn.Module):
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_kxk = layers.conv_norm_act(
mid_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False)
groups=groups, drop_layer=drop_block, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
self.act = nn.Identity() if linear_out else layers.act(inplace=True)
@ -1127,8 +1125,7 @@ class EdgeBlock(nn.Module):
apply_act=False, layers=layers)
self.conv1_kxk = layers.conv_norm_act(
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
in_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0], groups=groups, drop_layer=drop_block)
self.attn = nn.Identity() if attn_last or layers.attn is None else layers.attn(mid_chs)
self.conv2_1x1 = layers.conv_norm_act(mid_chs, out_chs, 1, apply_act=False)
self.attn_last = nn.Identity() if not attn_last or layers.attn is None else layers.attn(out_chs)
@ -1172,7 +1169,7 @@ class RepVggBlock(nn.Module):
self.identity = layers.norm_act(out_chs, apply_act=False) if use_ident else None
self.conv_kxk = layers.conv_norm_act(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block, apply_act=False)
groups=groups, drop_layer=drop_block, apply_act=False)
self.conv_1x1 = layers.conv_norm_act(in_chs, out_chs, 1, stride=stride, groups=groups, apply_act=False)
self.attn = nn.Identity() if layers.attn is None else layers.attn(out_chs)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. and use_ident else nn.Identity()
@ -1219,7 +1216,7 @@ class SelfAttnBlock(nn.Module):
if extra_conv:
self.conv2_kxk = layers.conv_norm_act(
mid_chs, mid_chs, kernel_size, stride=stride, dilation=dilation[0],
groups=groups, drop_block=drop_block)
groups=groups, drop_layer=drop_block)
stride = 1 # striding done via conv if enabled
else:
self.conv2_kxk = nn.Identity()
@ -1466,8 +1463,8 @@ def create_byob_stages(
def get_layer_fns(cfg: ByoModelCfg):
act = get_act_layer(cfg.act_layer)
norm_act = convert_norm_act(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvBnAct, norm_layer=cfg.norm_layer, act_layer=act)
norm_act = get_norm_act_layer(norm_layer=cfg.norm_layer, act_layer=act)
conv_norm_act = partial(ConvNormAct, norm_layer=cfg.norm_layer, act_layer=act)
attn = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self_attn = partial(get_attn(cfg.self_attn_layer), **cfg.self_attn_kwargs) if cfg.self_attn_layer else None
layer_fn = LayerFn(conv_norm_act=conv_norm_act, norm_act=norm_act, act=act, attn=attn, self_attn=self_attn)

@ -14,11 +14,10 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, DropPath, create_attn, get_norm_act_layer
from .layers import ClassifierHead, ConvNormAct, ConvNormActAa, DropPath, create_attn, get_norm_act_layer
from .registry import register_model
@ -130,7 +129,7 @@ model_cfgs = dict(
def create_stem(
in_chans=3, out_chs=32, kernel_size=3, stride=2, pool='',
act_layer=None, norm_layer=None, aa_layer=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
stem = nn.Sequential()
if not isinstance(out_chs, (tuple, list)):
out_chs = [out_chs]
@ -138,7 +137,7 @@ def create_stem(
in_c = in_chans
for i, out_c in enumerate(out_chs):
conv_name = f'conv{i + 1}'
stem.add_module(conv_name, ConvBnAct(
stem.add_module(conv_name, ConvNormAct(
in_c, out_c, kernel_size, stride=stride if i == 0 else 1,
act_layer=act_layer, norm_layer=norm_layer))
in_c = out_c
@ -161,12 +160,14 @@ class ResBottleneck(nn.Module):
attn_layer=None, aa_layer=None, drop_block=None, drop_path=None):
super(ResBottleneck, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
self.conv2 = ConvBnAct(mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
self.conv2 = ConvNormActAa(
mid_chs, mid_chs, kernel_size=3, dilation=dilation, groups=groups,
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
self.attn2 = create_attn(attn_layer, channels=mid_chs) if not attn_last else None
self.conv3 = ConvBnAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
self.conv3 = ConvNormAct(mid_chs, out_chs, kernel_size=1, apply_act=False, **ckwargs)
self.attn3 = create_attn(attn_layer, channels=out_chs) if attn_last else None
self.drop_path = drop_path
self.act3 = act_layer(inplace=True)
@ -201,9 +202,11 @@ class DarkBlock(nn.Module):
drop_block=None, drop_path=None):
super(DarkBlock, self).__init__()
mid_chs = int(round(out_chs * bottle_ratio))
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
self.conv1 = ConvBnAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
self.conv2 = ConvBnAct(mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups, **ckwargs)
ckwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvNormAct(in_chs, mid_chs, kernel_size=1, **ckwargs)
self.conv2 = ConvNormActAa(
mid_chs, out_chs, kernel_size=3, dilation=dilation, groups=groups,
aa_layer=aa_layer, drop_layer=drop_block, **ckwargs)
self.attn = create_attn(attn_layer, channels=out_chs)
self.drop_path = drop_path
@ -235,7 +238,7 @@ class CrossStage(nn.Module):
conv_kwargs = dict(act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'))
if stride != 1 or first_dilation != dilation:
self.conv_down = ConvBnAct(
self.conv_down = ConvNormActAa(
in_chs, down_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
aa_layer=block_kwargs.get('aa_layer', None), **conv_kwargs)
prev_chs = down_chs
@ -246,7 +249,7 @@ class CrossStage(nn.Module):
# FIXME this 1x1 expansion is pushed down into the cross and block paths in the darknet cfgs. Also,
# there is also a special case for the first stage of some models that results in an uneven split
# across the two paths. I did it this way for simplicity for now.
self.conv_exp = ConvBnAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
self.conv_exp = ConvNormAct(prev_chs, exp_chs, kernel_size=1, apply_act=not cross_linear, **conv_kwargs)
prev_chs = exp_chs // 2 # output of conv_exp is always split in two
self.blocks = nn.Sequential()
@ -257,8 +260,8 @@ class CrossStage(nn.Module):
prev_chs = block_out_chs
# transition convs
self.conv_transition_b = ConvBnAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
self.conv_transition = ConvBnAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
self.conv_transition_b = ConvNormAct(prev_chs, exp_chs // 2, kernel_size=1, **conv_kwargs)
self.conv_transition = ConvNormAct(exp_chs, out_chs, kernel_size=1, **conv_kwargs)
def forward(self, x):
if self.conv_down is not None:
@ -280,7 +283,7 @@ class DarkStage(nn.Module):
super(DarkStage, self).__init__()
first_dilation = first_dilation or dilation
self.conv_down = ConvBnAct(
self.conv_down = ConvNormActAa(
in_chs, out_chs, kernel_size=3, stride=stride, dilation=first_dilation, groups=groups,
act_layer=block_kwargs.get('act_layer'), norm_layer=block_kwargs.get('norm_layer'),
aa_layer=block_kwargs.get('aa_layer', None))
@ -437,7 +440,7 @@ def cspresnext50(pretrained=False, **kwargs):
@register_model
def cspresnext50_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
return _create_cspnet('cspresnext50_iabn', pretrained=pretrained, norm_layer=norm_layer, **kwargs)
@ -448,7 +451,7 @@ def cspdarknet53(pretrained=False, **kwargs):
@register_model
def cspdarknet53_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
return _create_cspnet('cspdarknet53_iabn', pretrained=pretrained, block_fn=DarkBlock, norm_layer=norm_layer, **kwargs)

@ -14,7 +14,7 @@ from torch.jit.annotations import List
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import BatchNormAct2d, create_norm_act, BlurPool2d, create_classifier
from .layers import BatchNormAct2d, create_norm_act_layer, BlurPool2d, create_classifier
from .registry import register_model
__all__ = ['DenseNet']
@ -370,7 +370,7 @@ def densenet264d_iabn(pretrained=False, **kwargs):
r"""Densenet-264 model with deep stem and Inplace-ABN
"""
def norm_act_fn(num_features, **kwargs):
return create_norm_act('iabn', num_features, **kwargs)
return create_norm_act_layer('iabn', num_features, act_layer='leaky_relu', **kwargs)
model = _create_densenet(
'densenet264d_iabn', growth_rate=48, block_config=(6, 12, 64, 48), stem_type='deep',
norm_layer=norm_act_fn, pretrained=pretrained, **kwargs)

@ -16,7 +16,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import BatchNormAct2d, ConvBnAct, create_conv2d, create_classifier
from .layers import BatchNormAct2d, ConvNormAct, create_conv2d, create_classifier
from .registry import register_model
__all__ = ['DPN']
@ -180,7 +180,7 @@ class DPN(nn.Module):
blocks = OrderedDict()
# conv1
blocks['conv1_1'] = ConvBnAct(
blocks['conv1_1'] = ConvNormAct(
in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]

@ -45,7 +45,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import create_conv2d, create_classifier
from .layers import create_conv2d, create_classifier, get_norm_act_layer, EvoNorm2dS0, GroupNormAct
from .registry import register_model
__all__ = ['EfficientNet', 'EfficientNetFeatures']
@ -117,6 +117,20 @@ default_cfgs = {
'efficientnet_l2': _cfg(
url='', input_size=(3, 800, 800), pool_size=(25, 25), crop_pct=0.961),
# FIXME experimental
'efficientnet_b0_gn': _cfg(
url=''),
'efficientnet_b0_g8': _cfg(
url=''),
'efficientnet_b0_g16_evos': _cfg(
url=''),
'efficientnet_b3_gn': _cfg(
url='',
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_b3_g8_gn': _cfg(
url='',
input_size=(3, 288, 288), pool_size=(9, 9), test_input_size=(3, 320, 320), crop_pct=1.0),
'efficientnet_es': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth'),
'efficientnet_em': _cfg(
@ -431,6 +445,7 @@ class EfficientNet(nn.Module):
super(EfficientNet, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes
self.num_features = num_features
@ -440,8 +455,7 @@ class EfficientNet(nn.Module):
if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
self.bn1 = norm_act_layer(stem_size, inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
@ -453,17 +467,16 @@ class EfficientNet(nn.Module):
# Head + Pooling
self.conv_head = create_conv2d(head_chs, self.num_features, 1, padding=pad_type)
self.bn2 = norm_layer(self.num_features)
self.act2 = act_layer(inplace=True)
self.bn2 = norm_act_layer(self.num_features, inplace=True)
self.global_pool, self.classifier = create_classifier(
self.num_features, self.num_classes, pool_type=global_pool)
efficientnet_init_weights(self)
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks)
layers.extend([self.conv_head, self.bn2, self.act2, self.global_pool])
layers.extend([self.conv_head, self.bn2, self.global_pool])
layers.extend([nn.Dropout(self.drop_rate), self.classifier])
return nn.Sequential(*layers)
@ -478,11 +491,9 @@ class EfficientNet(nn.Module):
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.conv_head(x)
x = self.bn2(x)
x = self.act2(x)
return x
def forward(self, x):
@ -506,6 +517,7 @@ class EfficientNetFeatures(nn.Module):
super(EfficientNetFeatures, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.drop_rate = drop_rate
@ -513,8 +525,7 @@ class EfficientNetFeatures(nn.Module):
if not fix_stem:
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
self.bn1 = norm_act_layer(stem_size, inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
@ -536,7 +547,6 @@ class EfficientNetFeatures(nn.Module):
def forward(self, x) -> List[torch.Tensor]:
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
if self.feature_hooks is None:
features = []
if 0 in self._stage_out_idx:
@ -767,7 +777,9 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
return model
def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnet(
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
group_size=None, pretrained=False, **kwargs):
"""Creates an EfficientNet model.
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@ -800,9 +812,9 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
['ir_r4_k5_s2_e6_c192_se0.25'],
['ir_r1_k3_s1_e6_c320_se0.25'],
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, divisor=channel_divisor)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
@ -814,7 +826,8 @@ def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pre
return model
def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
def _gen_efficientnet_edge(
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
""" Creates an EfficientNet-EdgeTPU model
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
@ -832,7 +845,7 @@ def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0
]
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(1280),
stem_size=32,
round_chs_fn=round_chs_fn,
@ -946,7 +959,7 @@ def _gen_efficientnetv2_base(
def _gen_efficientnetv2_s(
variant, channel_multiplier=1.0, depth_multiplier=1.0, rw=False, pretrained=False, **kwargs):
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs):
""" Creates an EfficientNet-V2 Small model
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@ -972,7 +985,7 @@ def _gen_efficientnetv2_s(
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
model_kwargs = dict(
block_args=decode_arch_def(arch_def, depth_multiplier),
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
num_features=round_chs_fn(num_features),
stem_size=24,
round_chs_fn=round_chs_fn,
@ -1366,6 +1379,52 @@ def efficientnet_l2(pretrained=False, **kwargs):
return model
# FIXME experimental group conv / GroupNorm / EvoNorm experiments
@register_model
def efficientnet_b0_gn(pretrained=False, **kwargs):
""" EfficientNet-B0 + GroupNorm"""
model = _gen_efficientnet(
'efficientnet_b0_gn', norm_layer=partial(GroupNormAct, group_size=8), pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b0_g8(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group conv + BN"""
model = _gen_efficientnet(
'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b0_g16_evos(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group 16 conv + EvoNorm"""
model = _gen_efficientnet(
'efficientnet_b0_g16_evos', group_size=16, channel_divisor=16,
norm_layer=partial(EvoNorm2dS0, group_size=16), pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b3_gn(pretrained=False, **kwargs):
""" EfficientNet-B3 w/ GroupNorm """
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b3_gn', channel_multiplier=1.2, depth_multiplier=1.4, channel_divisor=16,
norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_b3_g8_gn(pretrained=False, **kwargs):
""" EfficientNet-B3 w/ grouped conv + BN"""
# NOTE for train, drop_rate should be 0.3, drop_path_rate should be 0.2
model = _gen_efficientnet(
'efficientnet_b3_g8_gn', channel_multiplier=1.2, depth_multiplier=1.4, group_size=8, channel_divisor=16,
norm_layer=partial(GroupNormAct, group_size=16), pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_es(pretrained=False, **kwargs):
""" EfficientNet-Edge Small. """
@ -1373,6 +1432,7 @@ def efficientnet_es(pretrained=False, **kwargs):
'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs)
return model
@register_model
def efficientnet_es_pruned(pretrained=False, **kwargs):
""" EfficientNet-Edge Small Pruned. For more info: https://github.com/DeGirum/pruned-models/releases/tag/efficientnet_v1.0"""

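The grouped-conv / GroupNorm / EvoNorm variants registered above are experimental and have no pretrained weights yet (url=''), but they build through the normal factory. A minimal sketch, assuming a timm build that includes these registrations:

import torch
import timm

m = timm.create_model('efficientnet_b0_g8', pretrained=False)    # grouped conv (8 ch/group) + BN
print(m(torch.randn(1, 3, 224, 224)).shape)                      # torch.Size([1, 1000])

# the GroupNorm / EvoNorm variants swap the norm layer instead of (or as well as) the conv grouping
m_gn = timm.create_model('efficientnet_b0_gn', pretrained=False)          # depthwise conv + GroupNormAct(group_size=8)
m_b3 = timm.create_model('efficientnet_b3_g8_gn', pretrained=False,       # grouped conv + GroupNorm at B3 scale
                         drop_rate=0.3, drop_path_rate=0.2)               # per the training notes above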
@ -2,18 +2,31 @@
Hacked together by / Copyright 2020 Ross Wightman
"""
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from .layers import create_conv2d, drop_path, make_divisible, create_act_layer
from .layers.activations import sigmoid
from .layers import create_conv2d, DropPath, make_divisible, create_act_layer, get_norm_act_layer
__all__ = [
'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', 'InvertedResidual', 'CondConvResidual', 'EdgeResidual']
def num_groups(group_size, channels):
if not group_size: # 0 or None
return 1 # normal conv with 1 group
else:
# NOTE group_size == 1 -> depthwise conv
#assert channels % group_size == 0
if channels % group_size != 0:
num_groups = math.floor(channels / group_size)
print(channels, group_size, num_groups)
return int(num_groups)
return channels // group_size
class SqueezeExcite(nn.Module):
""" Squeeze-and-Excitation w/ specific features for EfficientNet/MobileNet family
@ -51,31 +64,30 @@ class ConvBnAct(nn.Module):
""" Conv + Norm Layer + Activation w/ optional skip connection
"""
def __init__(
self, in_chs, out_chs, kernel_size, stride=1, dilation=1, pad_type='',
self, in_chs, out_chs, kernel_size, stride=1, dilation=1, group_size=0, pad_type='',
skip=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_path_rate=0.):
super(ConvBnAct, self).__init__()
self.has_residual = skip and stride == 1 and in_chs == out_chs
self.drop_path_rate = drop_path_rate
self.conv = create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(out_chs)
self.act1 = act_layer(inplace=True)
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
groups = num_groups(group_size, in_chs)
self.has_skip = skip and stride == 1 and in_chs == out_chs
self.conv = create_conv2d(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
self.bn1 = norm_act_layer(out_chs, inplace=True)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # output of conv after act, same as block output
info = dict(module='act1', hook_type='forward', num_chs=self.conv.out_channels)
return dict(module='bn1', hook_type='forward', num_chs=self.conv.out_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv.out_channels)
return info
return dict(module='', hook_type='', num_chs=self.conv.out_channels)
def forward(self, x):
shortcut = x
x = self.conv(x)
x = self.bn1(x)
x = self.act1(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
@ -85,50 +97,41 @@ class DepthwiseSeparableConv(nn.Module):
(factor of 1.0). This is an alternative to having an IR with an optional first pw conv.
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, pw_kernel_size=1, pw_act=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
se_layer=None, drop_path_rate=0.):
super(DepthwiseSeparableConv, self).__init__()
self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
groups = num_groups(group_size, in_chs)
self.has_skip = (stride == 1 and in_chs == out_chs) and not noskip
self.has_pw_act = pw_act # activation after point-wise conv
self.drop_path_rate = drop_path_rate
self.conv_dw = create_conv2d(
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, depthwise=True)
self.bn1 = norm_layer(in_chs)
self.act1 = act_layer(inplace=True)
in_chs, in_chs, dw_kernel_size, stride=stride, dilation=dilation, padding=pad_type, groups=groups)
self.bn1 = norm_act_layer(in_chs, inplace=True)
# Squeeze-and-excitation
self.se = se_layer(in_chs, act_layer=act_layer) if se_layer else nn.Identity()
self.conv_pw = create_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs)
self.act2 = act_layer(inplace=True) if self.has_pw_act else nn.Identity()
self.bn2 = norm_act_layer(out_chs, inplace=True, apply_act=self.has_pw_act)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PW
info = dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
return dict(module='conv_pw', hook_type='forward_pre', num_chs=self.conv_pw.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
return info
return dict(module='', hook_type='', num_chs=self.conv_pw.out_channels)
def forward(self, x):
shortcut = x
x = self.conv_dw(x)
x = self.bn1(x)
x = self.act1(x)
x = self.se(x)
x = self.conv_pw(x)
x = self.bn2(x)
x = self.act2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
@ -143,66 +146,51 @@ class InvertedResidual(nn.Module):
"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, conv_kwargs=None, drop_path_rate=0.):
super(InvertedResidual, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
conv_kwargs = conv_kwargs or {}
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
groups = num_groups(group_size, mid_chs)
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
# Point-wise expansion
self.conv_pw = create_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs)
self.bn1 = norm_layer(mid_chs)
self.act1 = act_layer(inplace=True)
self.bn1 = norm_act_layer(mid_chs, inplace=True)
# Depth-wise convolution
self.conv_dw = create_conv2d(
mid_chs, mid_chs, dw_kernel_size, stride=stride, dilation=dilation,
padding=pad_type, depthwise=True, **conv_kwargs)
self.bn2 = norm_layer(mid_chs)
self.act2 = act_layer(inplace=True)
groups=groups, padding=pad_type, **conv_kwargs)
self.bn2 = norm_act_layer(mid_chs, inplace=True)
# Squeeze-and-excitation
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs)
self.bn3 = norm_layer(out_chs)
self.bn3 = norm_act_layer(out_chs, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # after SE, input to PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
def forward(self, x):
shortcut = x
# Point-wise expansion
x = self.conv_pw(x)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
@ -210,7 +198,7 @@ class CondConvResidual(InvertedResidual):
""" Inverted residual block w/ CondConv routing"""
def __init__(
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, pad_type='',
self, in_chs, out_chs, dw_kernel_size=3, stride=1, dilation=1, group_size=1, pad_type='',
noskip=False, exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, num_experts=0, drop_path_rate=0.):
@ -218,8 +206,8 @@ class CondConvResidual(InvertedResidual):
conv_kwargs = dict(num_experts=self.num_experts)
super(CondConvResidual, self).__init__(
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, pad_type=pad_type,
act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, dilation=dilation, group_size=group_size,
pad_type=pad_type, act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size, se_layer=se_layer, norm_layer=norm_layer, conv_kwargs=conv_kwargs,
drop_path_rate=drop_path_rate)
@ -227,32 +215,17 @@ class CondConvResidual(InvertedResidual):
def forward(self, x):
shortcut = x
# CondConv routing
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1)
pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) # CondConv routing
routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs))
# Point-wise expansion
x = self.conv_pw(x, routing_weights)
x = self.bn1(x)
x = self.act1(x)
# Depth-wise convolution
x = self.conv_dw(x, routing_weights)
x = self.bn2(x)
x = self.act2(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x, routing_weights)
x = self.bn3(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
if self.has_skip:
x = self.drop_path(x) + shortcut
return x
@ -269,55 +242,44 @@ class EdgeResidual(nn.Module):
"""
def __init__(
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, pad_type='',
self, in_chs, out_chs, exp_kernel_size=3, stride=1, dilation=1, group_size=0, pad_type='',
force_in_chs=0, noskip=False, exp_ratio=1.0, pw_kernel_size=1, act_layer=nn.ReLU,
norm_layer=nn.BatchNorm2d, se_layer=None, drop_path_rate=0.):
super(EdgeResidual, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
if force_in_chs > 0:
mid_chs = make_divisible(force_in_chs * exp_ratio)
else:
mid_chs = make_divisible(in_chs * exp_ratio)
self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
self.drop_path_rate = drop_path_rate
groups = num_groups(group_size, in_chs)
self.has_skip = (in_chs == out_chs and stride == 1) and not noskip
# Expansion convolution
self.conv_exp = create_conv2d(
in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, padding=pad_type)
self.bn1 = norm_layer(mid_chs)
self.act1 = act_layer(inplace=True)
in_chs, mid_chs, exp_kernel_size, stride=stride, dilation=dilation, groups=groups, padding=pad_type)
self.bn1 = norm_act_layer(mid_chs, inplace=True)
# Squeeze-and-excitation
self.se = se_layer(mid_chs, act_layer=act_layer) if se_layer else nn.Identity()
# Point-wise linear projection
self.conv_pwl = create_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type)
self.bn2 = norm_layer(out_chs)
self.bn2 = norm_act_layer(out_chs, apply_act=False)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate else nn.Identity()
def feature_info(self, location):
if location == 'expansion': # after SE, before PWL
info = dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
return dict(module='conv_pwl', hook_type='forward_pre', num_chs=self.conv_pwl.in_channels)
else: # location == 'bottleneck', block output
info = dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
return info
return dict(module='', hook_type='', num_chs=self.conv_pwl.out_channels)
def forward(self, x):
shortcut = x
# Expansion convolution
x = self.conv_exp(x)
x = self.bn1(x)
x = self.act1(x)
# Squeeze-and-excitation
x = self.se(x)
# Point-wise linear projection
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_residual:
if self.drop_path_rate > 0.:
x = drop_path(x, self.drop_path_rate, self.training)
x += shortcut
if self.has_skip:
x = self.drop_path(x) + shortcut
return x

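For reference, the group_size convention used by the updated blocks above (and by the 'gs' arch-string option): group_size counts channels per group, so 0/None keeps a normal convolution, 1 gives a depthwise convolution, and e.g. 8 yields channels // 8 groups. A tiny standalone sketch of that mapping (the helper name here is hypothetical, and it assumes channels divide evenly):

def groups_from_group_size(group_size, channels):
    # mirrors num_groups() above, minus the non-divisible fallback
    if not group_size:               # 0 or None -> ordinary conv
        return 1
    return channels // group_size    # group_size == 1 -> depthwise (groups == channels)

assert groups_from_group_size(None, 96) == 1    # regular conv
assert groups_from_group_size(1, 96) == 96      # depthwise conv
assert groups_from_group_size(8, 96) == 12      # grouped conv, 8 channels per group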
@ -139,60 +139,52 @@ def _decode_block_str(block_str):
exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1
pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1
force_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
num_repeat = int(options['r'])
# each type of block has different valid arguments, fill accordingly
block_args = dict(
block_type=block_type,
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
)
if block_type == 'ir':
block_args = dict(
block_type=block_type,
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
exp_kernel_size=exp_kernel_size,
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=skip is False,
)
))
if 'cc' in options:
block_args['num_experts'] = int(options['cc'])
elif block_type == 'ds' or block_type == 'dsa':
block_args = dict(
block_type=block_type,
block_args.update(dict(
dw_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
pw_act=block_type == 'dsa',
noskip=block_type == 'dsa' or skip is False,
)
))
elif block_type == 'er':
block_args = dict(
block_type=block_type,
block_args.update(dict(
exp_kernel_size=_parse_ksize(options['k']),
pw_kernel_size=pw_kernel_size,
out_chs=int(options['c']),
exp_ratio=float(options['e']),
force_in_chs=force_in_chs,
se_ratio=float(options['se']) if 'se' in options else 0.,
stride=int(options['s']),
act_layer=act_layer,
noskip=skip is False,
)
))
elif block_type == 'cn':
block_args = dict(
block_type=block_type,
block_args.update(dict(
kernel_size=int(options['k']),
out_chs=int(options['c']),
stride=int(options['s']),
act_layer=act_layer,
skip=skip is True,
)
))
else:
assert False, 'Unknown block type (%s)' % block_type
if 'gs' in options:
block_args['group_size'] = options['gs']
return block_args, num_repeat
@ -235,7 +227,27 @@ def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='c
return sa_scaled
def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False):
def decode_arch_def(
arch_def,
depth_multiplier=1.0,
depth_trunc='ceil',
experts_multiplier=1,
fix_first_last=False,
group_size=None,
):
""" Decode block architecture definition strings -> block kwargs
Args:
arch_def: architecture definition strings, list of list of strings
depth_multiplier: network depth multiplier
depth_trunc: network depth truncation mode when applying multiplier
experts_multiplier: CondConv experts multiplier
fix_first_last: fix first and last block depths when multiplier is applied
group_size: group size override for all blocks whose group size wasn't explicitly set in the arch string
Returns:
list of list of block kwargs
"""
arch_args = []
if isinstance(depth_multiplier, tuple):
assert len(depth_multiplier) == len(arch_def)
@ -250,6 +262,8 @@ def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_
ba, rep = _decode_block_str(block_str)
if ba.get('num_experts', 0) > 0 and experts_multiplier > 1:
ba['num_experts'] *= experts_multiplier
if group_size is not None:
ba.setdefault('group_size', group_size)
stack_args.append(ba)
repeats.append(rep)
if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1):

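A minimal sketch of the new group_size plumbing in the arch decoder, reusing two of the block strings shown above and assuming the decoder lives at timm.models.efficientnet_builder as in current timm; the group_size argument only fills in blocks that did not set 'gs' in their own string:

from timm.models.efficientnet_builder import decode_arch_def

arch_def = [
    ['ir_r1_k3_s1_e6_c320_se0.25'],
    ['ir_r4_k5_s2_e6_c192_se0.25'],
]

block_args = decode_arch_def(arch_def, depth_multiplier=1.0, group_size=8)
print(block_args[0][0]['group_size'])   # 8, applied as the default for every decoded block
print(len(block_args[1]))               # 4 repeats of the second stage definition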
@ -7,11 +7,11 @@ from .cond_conv2d import CondConv2d, get_condconv_initializer
from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set_scriptable, set_no_jit,\
set_layer_config
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvBnAct
from .conv_bn_act import ConvNormAct, ConvNormActAa, ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
from .create_norm_act import get_norm_act_layer, create_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn
from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\
@ -32,7 +32,7 @@ from .patch_embed import PatchEmbed
from .pool2d_same import AvgPool2dSame, create_pool2d
from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
from .selective_kernel import SelectiveKernel
from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .separable_conv import SeparableConv2d, SeparableConvNormAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttn
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model

@ -11,7 +11,7 @@ import torch
from torch import nn as nn
import torch.nn.functional as F
from .conv_bn_act import ConvBnAct
from .conv_bn_act import ConvNormAct
from .create_act import create_act_layer, get_act_layer
from .helpers import make_divisible
@ -56,7 +56,7 @@ class SpatialAttn(nn.Module):
"""
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(SpatialAttn, self).__init__()
self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)
self.conv = ConvNormAct(2, 1, kernel_size, apply_act=False)
self.gate = create_act_layer(gate_layer)
def forward(self, x):
@ -70,7 +70,7 @@ class LightSpatialAttn(nn.Module):
"""
def __init__(self, kernel_size=7, gate_layer='sigmoid'):
super(LightSpatialAttn, self).__init__()
self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)
self.conv = ConvNormAct(1, 1, kernel_size, apply_act=False)
self.gate = create_act_layer(gate_layer)
def forward(self, x):

@ -5,14 +5,46 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act
from .create_norm_act import get_norm_act_layer
class ConvBnAct(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None,
drop_block=None):
super(ConvBnAct, self).__init__()
class ConvNormAct(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, drop_layer=None):
super(ConvNormAct, self).__init__()
self.conv = create_conv2d(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
@property
def in_channels(self):
return self.conv.in_channels
@property
def out_channels(self):
return self.conv.out_channels
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
ConvBnAct = ConvNormAct
class ConvNormActAa(nn.Module):
def __init__(
self, in_channels, out_channels, kernel_size=1, stride=1, padding='', dilation=1, groups=1,
bias=False, apply_act=True, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU, aa_layer=None, drop_layer=None):
super(ConvNormActAa, self).__init__()
use_aa = aa_layer is not None
self.conv = create_conv2d(
@ -20,9 +52,11 @@ class ConvBnAct(nn.Module):
padding=padding, dilation=dilation, groups=groups, bias=bias)
# NOTE for backwards compatibility with models that use separate norm and act layer definitions
norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else None
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
# NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
self.aa = aa_layer(channels=out_channels) if stride == 2 and use_aa else nn.Identity()
@property
def in_channels(self):
@ -35,6 +69,5 @@ class ConvBnAct(nn.Module):
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
if self.aa is not None:
x = self.aa(x)
x = self.aa(x)
return x

@ -16,7 +16,12 @@ def create_conv2d(in_channels, out_channels, kernel_size, **kwargs):
"""
if isinstance(kernel_size, list):
assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently
assert 'groups' not in kwargs # MixedConv groups are defined by kernel list
if 'groups' in kwargs:
groups = kwargs.pop('groups')
if groups == in_channels:
kwargs['depthwise'] = True
else:
assert groups == 1
# We're going to use only lists for defining the MixedConv2d kernel groups,
# ints, tuples, other iterables will continue to pass to normal conv and specify h, w.
m = MixedConv2d(in_channels, out_channels, kernel_size, **kwargs)

@ -11,12 +11,15 @@ import functools
from .evo_norm import *
from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
from .norm_act import BatchNormAct2d, GroupNormAct
from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
from .inplace_abn import InplaceAbn
_NORM_ACT_MAP = dict(
batchnorm=BatchNormAct2d,
batchnorm2d=BatchNormAct2d,
groupnorm=GroupNormAct,
layernorm=LayerNormAct,
layernorm2d=LayerNormAct2d,
evonormb0=EvoNorm2dB0,
evonormb1=EvoNorm2dB1,
evonormb2=EvoNorm2dB2,
@ -33,28 +36,19 @@ _NORM_ACT_MAP = dict(
)
_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
# has act_layer arg to define act type
_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn}
_NORM_ACT_REQUIRES_ARG = {
BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
def get_norm_act_layer(layer_name):
layer_name = layer_name.replace('_', '').lower().split('-')[0]
layer = _NORM_ACT_MAP.get(layer_name, None)
assert layer is not None, "Invalid norm_act layer (%s)" % layer_name
return layer
def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs):
layer_parts = layer_name.split('-') # e.g. batchnorm-leaky_relu
assert len(layer_parts) in (1, 2)
layer = get_norm_act_layer(layer_parts[0])
#activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection?
def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
layer = get_norm_act_layer(layer_name, act_layer=act_layer)
layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
if jit:
layer_instance = torch.jit.script(layer_instance)
return layer_instance
def convert_norm_act(norm_layer, act_layer):
def get_norm_act_layer(norm_layer, act_layer=None):
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
norm_act_kwargs = {}
@ -65,7 +59,8 @@ def convert_norm_act(norm_layer, act_layer):
norm_layer = norm_layer.func
if isinstance(norm_layer, str):
norm_act_layer = get_norm_act_layer(norm_layer)
layer_name = norm_layer.replace('_', '').lower().split('-')[0]
norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
elif norm_layer in _NORM_ACT_TYPES:
norm_act_layer = norm_layer
elif isinstance(norm_layer, types.FunctionType):
@ -77,6 +72,10 @@ def convert_norm_act(norm_layer, act_layer):
norm_act_layer = BatchNormAct2d
elif type_name.startswith('groupnorm'):
norm_act_layer = GroupNormAct
elif type_name.startswith('layernorm2d'):
norm_act_layer = LayerNormAct2d
elif type_name.startswith('layernorm'):
norm_act_layer = LayerNormAct
else:
assert False, f"No equivalent norm_act layer for {type_name}"
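In short, get_norm_act_layer now subsumes the old convert_norm_act (string, type, or partial in, norm+act factory out), and create_norm_act_layer takes an explicit act_layer, which is why the iabn-based models above now pass act_layer='leaky_relu'. A minimal sketch, assuming the exports shown in layers/__init__.py above:

import torch.nn as nn
from timm.models.layers import get_norm_act_layer, create_norm_act_layer

# string, plain norm type, or functools.partial all resolve to a norm+act layer factory
bn_act = get_norm_act_layer('batchnorm2d', act_layer=nn.SiLU)   # BatchNormAct2d factory w/ SiLU bound
ln_act = get_norm_act_layer(nn.LayerNorm)                        # maps to LayerNormAct

# instantiate directly; 'iabn' needs the act baked in (and the inplace-abn package installed)
gn = create_norm_act_layer('groupnorm', 64, act_layer='relu')
# iabn = create_norm_act_layer('iabn', 64, act_layer='leaky_relu')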

@ -20,7 +20,7 @@ import torch.nn.functional as F
def drop_block_2d(
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
@ -32,7 +32,7 @@ def drop_block_2d(
clipped_block_size = min(block_size, min(W, H))
# seed_drop_rate, the gamma parameter
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
(W - block_size + 1) * (H - block_size + 1))
(W - block_size + 1) * (H - block_size + 1))
# Forces the block to be inside the feature map.
w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
@ -104,14 +104,16 @@ def drop_block_fast_2d(
class DropBlock2d(nn.Module):
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
"""
def __init__(self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
def __init__(
self,
drop_prob=0.1,
block_size=7,
gamma_scale=1.0,
with_noise=False,
inplace=False,
batchwise=False,
fast=True):
super(DropBlock2d, self).__init__()
self.drop_prob = drop_prob
self.gamma_scale = gamma_scale
@ -155,6 +157,7 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=None, scale_by_keep=True):
super(DropPath, self).__init__()
self.drop_prob = drop_prob

@ -23,6 +23,7 @@ GPU, similar train speeds for EvoNormS variants and BatchNorm.
Hacked together by / Copyright 2020 Ross Wightman
"""
from typing import Sequence, Union
import torch
import torch.nn as nn
@ -33,41 +34,57 @@ from .trace_utils import _assert
def instance_std(x, eps: float = 1e-5):
rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
return rms.expand(x.shape)
std = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype)
return std.expand(x.shape)
def instance_std_tpu(x, eps: float = 1e-5):
std = manual_var(x, dim=(2, 3)).add(eps).sqrt()
return std.expand(x.shape)
# instance_std = instance_std_tpu
def instance_rms(x, eps: float = 1e-5):
rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype)
rms = x.float().square().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(x.dtype)
return rms.expand(x.shape)
def manual_var(x, dim: Union[int, Sequence[int]], diff_sqm: bool = False):
xm = x.mean(dim=dim, keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU but can be less stable
var = ((x * x).mean(dim=dim, keepdim=True) - (xm * xm)).clamp(0)
else:
var = ((x - xm) * (x - xm)).mean(dim=dim, keepdim=True)
return var
def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False):
B, C, H, W = x.shape
x_dtype = x.dtype
_assert(C % groups == 0, '')
# x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
# std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype)
if flatten:
x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
else:
x = x.reshape(B, groups, C // groups, H, W)
std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt().to(x_dtype)
return std.expand(x.shape).reshape(B, C, H, W)
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False):
def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False, flatten: bool = False):
# This is a workaround for some stability / odd behaviour of .var and .std
# running on PyTorch XLA w/ TPUs. These manual var impls produce much better results
B, C, H, W = x.shape
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.float().reshape(B, groups, C // groups, H, W)
xm = x.mean(dim=(2, 3, 4), keepdim=True)
if diff_sqm:
# difference of squared mean and mean squared, faster on TPU
var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0)
if flatten:
x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues
var = manual_var(x, dim=-1, diff_sqm=diff_sqm)
else:
var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype)
# group_std = group_std_tpu # temporary, for TPU / PT XLA
x = x.reshape(B, groups, C // groups, H, W)
var = manual_var(x, dim=(2, 3, 4), diff_sqm=diff_sqm)
return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W)
#group_std = group_std_tpu # FIXME TPU temporary
def group_rms(x, groups: int = 32, eps: float = 1e-5):
@ -75,8 +92,8 @@ def group_rms(x, groups: int = 32, eps: float = 1e-5):
_assert(C % groups == 0, '')
x_dtype = x.dtype
x = x.reshape(B, groups, C // groups, H, W)
sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return sqm.expand(x.shape).reshape(B, C, H, W)
rms = x.float().square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype)
return rms.expand(x.shape).reshape(B, C, H, W)
class EvoNorm2dB0(nn.Module):
@ -104,6 +121,7 @@ class EvoNorm2dB0(nn.Module):
if self.v is not None:
if self.training:
var = x.float().var(dim=(0, 2, 3), unbiased=False)
# var = manual_var(x, dim=(0, 2, 3)).squeeze()
n = x.numel() / x.shape[1]
self.running_var.copy_(
self.running_var * (1 - self.momentum) +
@ -230,7 +248,7 @@ class EvoNorm2dS0a(EvoNorm2dS0):
d = group_std(x, self.groups, self.eps)
if self.v is not None:
v = self.v.view(v_shape).to(dtype=x_dtype)
x = x * (x * v).sigmoid_()
x = x * (x * v).sigmoid()
x = x / d
return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype)
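For reference, a quick standalone check (not part of the commit) of what the grouped std above computes: a per-group standard deviation broadcast back to (B, C, H, W). The flatten path is only a reshape workaround for XLA, not a different statistic:

import torch

B, C, H, W, groups = 2, 32, 8, 8, 8
x = torch.randn(B, C, H, W)
eps = 1e-5

xg = x.reshape(B, groups, C // groups, H, W)
std_5d = xg.var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt()
std_5d = std_5d.expand(xg.shape).reshape(B, C, H, W)

xf = x.reshape(B, groups, -1)
std_flat = xf.var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt()
std_flat = std_flat.expand(xf.shape).reshape(B, C, H, W)

print(torch.allclose(std_5d, std_flat))   # True: same per-group std, different reshape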

@ -38,7 +38,7 @@ class InplaceAbn(nn.Module):
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, apply_act=True,
act_layer="leaky_relu", act_param=0.01, drop_block=None):
act_layer="leaky_relu", act_param=0.01, drop_layer=None):
super(InplaceAbn, self).__init__()
self.num_features = num_features
self.affine = affine
@ -54,7 +54,7 @@ class InplaceAbn(nn.Module):
self.act_name = 'elu'
elif act_layer == nn.LeakyReLU:
self.act_name = 'leaky_relu'
elif act_layer == nn.Identity:
elif act_layer is None or act_layer == nn.Identity:
self.act_name = 'identity'
else:
assert False, f'Invalid act layer {act_layer.__name__} for IABN'

@ -8,7 +8,7 @@ import torch
from torch import nn
from torch.nn import functional as F
from .conv_bn_act import ConvBnAct
from .conv_bn_act import ConvNormAct
from .helpers import make_divisible
from .trace_utils import _assert
@ -74,10 +74,10 @@ class BilinearAttnTransform(nn.Module):
def __init__(self, in_channels, block_size, groups, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(BilinearAttnTransform, self).__init__()
self.conv1 = ConvBnAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvNormAct(in_channels, groups, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv_p = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(block_size, 1))
self.conv_q = nn.Conv2d(groups, block_size * block_size * groups, kernel_size=(1, block_size))
self.conv2 = ConvBnAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvNormAct(in_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.block_size = block_size
self.groups = groups
self.in_channels = in_channels
@ -132,9 +132,9 @@ class BatNonLocalAttn(nn.Module):
super().__init__()
if rd_channels is None:
rd_channels = make_divisible(in_channels * rd_ratio, divisor=rd_divisor)
self.conv1 = ConvBnAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvNormAct(in_channels, rd_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.ba = BilinearAttnTransform(rd_channels, block_size, groups, act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvBnAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.conv2 = ConvNormAct(rd_channels, in_channels, 1, act_layer=act_layer, norm_layer=norm_layer)
self.dropout = nn.Dropout2d(p=drop_rate)
def forward(self, x):

@ -1,5 +1,7 @@
""" Normalization + Activation Layers
"""
from typing import Union, List
import torch
from torch import nn as nn
from torch.nn import functional as F
@ -14,12 +16,13 @@ class BatchNormAct2d(nn.BatchNorm2d):
compatible with weights trained with separate bn, act. This is why we inherit from BN
instead of composing it as a .bn member.
"""
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
def __init__(
self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(BatchNormAct2d, self).__init__(
num_features, eps=eps, momentum=momentum, affine=affine, track_running_stats=track_running_stats)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
@ -29,8 +32,8 @@ class BatchNormAct2d(nn.BatchNorm2d):
def _forward_jit(self, x):
""" A cut & paste of the contents of the PyTorch BatchNorm2d forward function
"""
# exponential_average_factor is self.momentum set to
# (when it is available) only so that if gets updated
# exponential_average_factor is set to self.momentum
# (when it is available) only so that it gets updated
# in ONNX graph when this node is exported to ONNX.
if self.momentum is None:
exponential_average_factor = 0.0
@ -39,18 +42,38 @@ class BatchNormAct2d(nn.BatchNorm2d):
if self.training and self.track_running_stats:
# TODO: if statement only here to tell the jit to skip emitting this when it is None
if self.num_batches_tracked is not None:
self.num_batches_tracked += 1
if self.num_batches_tracked is not None: # type: ignore[has-type]
self.num_batches_tracked = self.num_batches_tracked + 1 # type: ignore[has-type]
if self.momentum is None: # use cumulative moving average
exponential_average_factor = 1.0 / float(self.num_batches_tracked)
else: # use exponential moving average
exponential_average_factor = self.momentum
x = F.batch_norm(
x, self.running_mean, self.running_var, self.weight, self.bias,
self.training or not self.track_running_stats,
exponential_average_factor, self.eps)
return x
r"""
Decide whether the mini-batch stats should be used for normalization rather than the buffers.
Mini-batch stats are used in training mode, and in eval mode when buffers are None.
"""
if self.training:
bn_training = True
else:
bn_training = (self.running_mean is None) and (self.running_var is None)
r"""
Buffers are only updated if they are to be tracked and we are in training mode. Thus they only need to be
passed when the update should occur (i.e. in training mode when they are tracked), or when buffer stats are
used for normalization (i.e. in eval mode when buffers are not None).
"""
return F.batch_norm(
x,
# If buffers are not to be tracked, ensure that they won't be updated
self.running_mean if not self.training or self.track_running_stats else None,
self.running_var if not self.training or self.track_running_stats else None,
self.weight,
self.bias,
bn_training,
exponential_average_factor,
self.eps,
)
@torch.jit.ignore
def _forward_python(self, x):
@ -62,17 +85,27 @@ class BatchNormAct2d(nn.BatchNorm2d):
x = self._forward_jit(x)
else:
x = self._forward_python(x)
x = self.drop(x)
x = self.act(x)
return x
def _num_groups(num_channels, num_groups, group_size):
if group_size:
assert num_channels % group_size == 0
return num_channels // group_size
return num_groups
class GroupNormAct(nn.GroupNorm):
# NOTE num_channels and num_groups order flipped for easier layer swaps / binding of fixed args
def __init__(self, num_channels, num_groups=32, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
super(GroupNormAct, self).__init__(num_groups, num_channels, eps=eps, affine=affine)
if isinstance(act_layer, str):
act_layer = get_act_layer(act_layer)
def __init__(
self, num_channels, num_groups=32, eps=1e-5, affine=True, group_size=None,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(GroupNormAct, self).__init__(
_num_groups(num_channels, num_groups, group_size), num_channels, eps=eps, affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
@ -81,5 +114,47 @@ class GroupNormAct(nn.GroupNorm):
def forward(self, x):
x = F.group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct(nn.LayerNorm):
def __init__(
self, normalization_shape: Union[int, List[int], torch.Size], eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def forward(self, x):
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = self.drop(x)
x = self.act(x)
return x
class LayerNormAct2d(nn.LayerNorm):
def __init__(
self, num_channels, eps=1e-5, affine=True,
apply_act=True, act_layer=nn.ReLU, inplace=True, drop_layer=None):
super(LayerNormAct2d, self).__init__(num_channels, eps=eps, elementwise_affine=affine)
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
act_layer = get_act_layer(act_layer) # string -> nn.Module
if act_layer is not None and apply_act:
act_args = dict(inplace=True) if inplace else {}
self.act = act_layer(**act_args)
else:
self.act = nn.Identity()
def forward(self, x):
x = F.layer_norm(
x.permute(0, 2, 3, 1), self.normalized_shape, self.weight, self.bias, self.eps).permute(0, 3, 1, 2)
x = self.drop(x)
x = self.act(x)
return x
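The refactor above replaces the old drop_block argument with a generic drop_layer factory and routes string activation names through get_act_layer. A minimal usage sketch, assuming this branch is installed and BatchNormAct2d is exported from timm.models.layers as the imports elsewhere in this diff indicate; the Dropout2d partial merely stands in for any drop-module factory:

from functools import partial

import torch
import torch.nn as nn
from timm.models.layers import BatchNormAct2d

# norm -> drop -> act, matching the forward() above; drop_layer is any
# zero-arg callable returning an nn.Module (here a plain Dropout2d partial).
bn_act = BatchNormAct2d(64, act_layer=nn.SiLU, drop_layer=partial(nn.Dropout2d, 0.1))

# Because the layer inherits from nn.BatchNorm2d instead of wrapping a .bn member,
# its state_dict keys match a plain BatchNorm2d, so separately trained bn weights load as-is.
assert set(bn_act.state_dict()) == set(nn.BatchNorm2d(64).state_dict())
y = bn_act(torch.randn(2, 64, 8, 8))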

@ -0,0 +1,143 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from .helpers import to_2tuple
from .weight_init import trunc_normal_
def rel_logits_1d(q, rel_k, permute_mask: List[int]):
""" Compute relative logits along one dimension
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
Args:
q: (batch, heads, height, width, dim)
rel_k: (2 * width - 1, dim)
permute_mask: permute output dim according to this
"""
B, H, W, dim = q.shape
x = (q @ rel_k.transpose(-1, -2))
x = x.reshape(-1, W, 2 * W - 1)
# pad to shift from relative to absolute indexing
x_pad = F.pad(x, [0, 1]).flatten(1)
x_pad = F.pad(x_pad, [0, W - 1])
# reshape and slice out the padded elements
x_pad = x_pad.reshape(-1, W + 1, 2 * W - 1)
x = x_pad[:, :W, W - 1:]
# reshape and tile
x = x.reshape(B, H, 1, W, W).expand(-1, -1, H, -1, -1)
return x.permute(permute_mask)
class PosEmbedRel(nn.Module):
""" Relative Position Embedding
As per: https://gist.github.com/aravindsrinivas/56359b79f0ce4449bcb04ab4b56a57a2
Originally from: `Attention Augmented Convolutional Networks` - https://arxiv.org/abs/1904.09925
"""
def __init__(self, feat_size, dim_head, scale):
super().__init__()
self.height, self.width = to_2tuple(feat_size)
self.dim_head = dim_head
self.scale = scale
self.height_rel = nn.Parameter(torch.randn(self.height * 2 - 1, dim_head) * self.scale)
self.width_rel = nn.Parameter(torch.randn(self.width * 2 - 1, dim_head) * self.scale)
def forward(self, q):
B, num_heads, HW, _ = q.shape
# relative logits in width dimension.
q = q.reshape(B * num_heads, self.height, self.width, -1)
rel_logits_w = rel_logits_1d(q, self.width_rel, permute_mask=(0, 1, 3, 2, 4))
# relative logits in height dimension.
q = q.transpose(1, 2)
rel_logits_h = rel_logits_1d(q, self.height_rel, permute_mask=(0, 3, 1, 4, 2))
rel_logits = rel_logits_h + rel_logits_w
rel_logits = rel_logits.reshape(B, num_heads, HW, HW)
return rel_logits
class BottleneckAttn(nn.Module):
""" Bottleneck Attention
Paper: `Bottleneck Transformers for Visual Recognition` - https://arxiv.org/abs/2101.11605
"""
def __init__(self, dim, dim_out=None, feat_size=None, stride=1, num_heads=4, qkv_bias=False):
super().__init__()
assert feat_size is not None, 'A concrete feature size matching expected input (H, W) is required'
dim_out = dim_out or dim
assert dim_out % num_heads == 0
self.num_heads = num_heads
self.dim_out = dim_out
self.dim_head = dim_out // num_heads
self.scale = self.dim_head ** -0.5
self.qkv = nn.Conv2d(dim, self.dim_out * 3, 1, bias=qkv_bias)
# NOTE I'm only supporting relative pos embedding for now
self.pos_embed = PosEmbedRel(feat_size, dim_head=self.dim_head, scale=self.scale)
self.pool = nn.AvgPool2d(2, 2) if stride == 2 else nn.Identity()
self.reset_parameters()
def reset_parameters(self):
trunc_normal_(self.qkv.weight, std=self.qkv.weight.shape[1] ** -0.5)
trunc_normal_(self.pos_embed.height_rel, std=self.scale)
trunc_normal_(self.pos_embed.width_rel, std=self.scale)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.pos_embed.height
assert W == self.pos_embed.width
x = self.qkv(x) # B, 3 * num_heads * dim_head, H, W
x = x.reshape(B, -1, self.dim_head, H * W).transpose(-1, -2)
q, k, v = torch.split(x, self.num_heads, dim=1)
attn_logits = (q @ k.transpose(-1, -2)) * self.scale
attn_logits = attn_logits + self.pos_embed(q) # B, num_heads, H * W, H * W
attn_out = attn_logits.softmax(dim=-1)
attn_out = (attn_out @ v).transpose(-1, -2).reshape(B, self.dim_out, H, W) # B, dim_out, H, W
attn_out = self.pool(attn_out)
return attn_out
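A usage sketch for the block above. The import path is an assumption (this diff adds the module to timm's layers package, presumably as bottleneck_attn); feat_size must match the actual input resolution because the relative position tables are sized from it:

import torch
from timm.models.layers import BottleneckAttn  # assumed export; otherwise import from the new module directly

attn = BottleneckAttn(dim=256, dim_out=256, feat_size=(16, 16), stride=1, num_heads=4)
x = torch.randn(2, 256, 16, 16)   # H, W must equal feat_size (asserted in forward)
y = attn(x)                       # (2, 256, 16, 16); stride=2 would halve H and W via the AvgPool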
class PoolingAttention(nn.Module):
def __init__(self, in_features: int, attention_features: int, segments: int, max_pool_kernel: int):
super(PoolingAttention, self).__init__()
self.attn = nn.Linear(in_features, attention_features * 5)
self.segments = segments
self.max_pool_kernel = max_pool_kernel
def forward(self, inp: torch.Tensor): # Shape: [Batch, Sequence, Features]
batch, sequence, features = inp.size()
assert sequence % self.segments == 0
qry, key, val, seg, loc = self.attn(inp).chunk(5, 2) # 5x Shape: [Batch, Sequence, AttentionFeatures]
aggregated = qry.mean(1)  # Shape: [Batch, AttentionFeatures]
aggregated = torch.einsum("ba,bsa->bs", aggregated, key) # Shape: [Batch, Sequence]
aggregated = F.softmax(aggregated, 1)
aggregated = torch.einsum("bs,bsa,bza->bza", aggregated, val,
qry) # Shape: [Batch, Sequence, AttentionFeatures]
pooled_sequence = sequence // self.segments
segment_max_pooled = seg.view(batch, pooled_sequence, self.segments, -1)
segment_max_pooled = segment_max_pooled.max(2, keepdim=True).values  # Shape: [Batch, PooledSequence, 1, AttentionFeatures]
segment_max_pooled = segment_max_pooled * qry.view(batch, pooled_sequence, self.segments, -1) # Shape: [Batch, PooledSequence, PoolSize, AttentionFeatures]
segment_max_pooled = segment_max_pooled.view(batch, sequence, -1) # Shape: [Batch, Sequence, AttentionFeatures]
loc = loc.transpose(1, 2) # Shape: [Batch, AttentionFeatures, Sequence]
local_max_pooled = F.max_pool1d(loc, self.max_pool_kernel, 1, self.max_pool_kernel // 2)
local_max_pooled = local_max_pooled.transpose(1, 2) # Shape: [Batch, Sequence, AttentionFeatures]
return aggregated + segment_max_pooled + local_max_pooled

@ -7,7 +7,7 @@ Hacked together by / Copyright 2020 Ross Wightman
import torch
from torch import nn as nn
from .conv_bn_act import ConvBnAct
from .conv_bn_act import ConvNormActAa
from .helpers import make_divisible
from .trace_utils import _assert
@ -20,8 +20,7 @@ def _kernel_valid(k):
class SelectiveKernelAttn(nn.Module):
def __init__(self, channels, num_paths=2, attn_channels=32,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
def __init__(self, channels, num_paths=2, attn_channels=32, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
""" Selective Kernel Attention Module
Selective Kernel attention mechanism factored out into its own module.
@ -51,7 +50,7 @@ class SelectiveKernel(nn.Module):
def __init__(self, in_channels, out_channels=None, kernel_size=None, stride=1, dilation=1, groups=1,
rd_ratio=1./16, rd_channels=None, rd_divisor=8, keep_3x3=True, split_input=True,
drop_block=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None):
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None, drop_layer=None):
""" Selective Kernel Convolution Module
As described in Selective Kernel Networks (https://arxiv.org/abs/1903.06586) with some modifications.
@ -72,9 +71,10 @@ class SelectiveKernel(nn.Module):
keep_3x3 (bool): keep all branch convolution kernels as 3x3, changing larger kernels for dilations
split_input (bool): split input channels evenly across each convolution branch, keeps param count lower,
can be viewed as grouping by path, output expands to module out_channels count
drop_block (nn.Module): drop block module
act_layer (nn.Module): activation layer to use
norm_layer (nn.Module): batchnorm/norm layer to use
aa_layer (nn.Module): anti-aliasing module
drop_layer (nn.Module): spatial drop module in convs (drop block, etc)
"""
super(SelectiveKernel, self).__init__()
out_channels = out_channels or in_channels
@ -97,15 +97,14 @@ class SelectiveKernel(nn.Module):
groups = min(out_channels, groups)
conv_kwargs = dict(
stride=stride, groups=groups, drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer)
stride=stride, groups=groups, act_layer=act_layer, norm_layer=norm_layer,
aa_layer=aa_layer, drop_layer=drop_layer)
self.paths = nn.ModuleList([
ConvBnAct(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
ConvNormActAa(in_channels, out_channels, kernel_size=k, dilation=d, **conv_kwargs)
for k, d in zip(kernel_size, dilation)])
attn_channels = rd_channels or make_divisible(out_channels * rd_ratio, divisor=rd_divisor)
self.attn = SelectiveKernelAttn(out_channels, self.num_paths, attn_channels)
self.drop_block = drop_block
def forward(self, x):
if self.split_input:
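Apart from the drop_block -> drop_layer rename, SelectiveKernel is constructed as before; a brief sketch, assuming it stays exported from timm.models.layers (as the sknet imports later in this diff show) and that the default branch setup is two 3x3 paths with dilations 1 and 2:

import torch
from timm.models.layers import SelectiveKernel

sk = SelectiveKernel(64, 64)            # default: two 3x3 paths, split_input=True (in_channels must divide evenly)
y = sk(torch.randn(2, 64, 32, 32))      # (2, 64, 32, 32): per-path outputs fused by SelectiveKernelAttn above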

@ -8,16 +8,16 @@ Hacked together by / Copyright 2020 Ross Wightman
from torch import nn as nn
from .create_conv2d import create_conv2d
from .create_norm_act import convert_norm_act
from .create_norm_act import get_norm_act_layer
class SeparableConvBnAct(nn.Module):
class SeparableConvNormAct(nn.Module):
""" Separable Conv w/ trailing Norm and Activation
"""
def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, padding='', bias=False,
channel_multiplier=1.0, pw_kernel_size=1, norm_layer=nn.BatchNorm2d, act_layer=nn.ReLU,
apply_act=True, drop_block=None):
super(SeparableConvBnAct, self).__init__()
apply_act=True, drop_layer=None):
super(SeparableConvNormAct, self).__init__()
self.conv_dw = create_conv2d(
in_channels, int(in_channels * channel_multiplier), kernel_size,
@ -26,8 +26,9 @@ class SeparableConvBnAct(nn.Module):
self.conv_pw = create_conv2d(
int(in_channels * channel_multiplier), out_channels, pw_kernel_size, padding=padding, bias=bias)
norm_act_layer = convert_norm_act(norm_layer, act_layer)
self.bn = norm_act_layer(out_channels, apply_act=apply_act, drop_block=drop_block)
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
norm_kwargs = dict(drop_layer=drop_layer) if drop_layer is not None else {}
self.bn = norm_act_layer(out_channels, apply_act=apply_act, **norm_kwargs)
@property
def in_channels(self):
@ -40,11 +41,13 @@ class SeparableConvBnAct(nn.Module):
def forward(self, x):
x = self.conv_dw(x)
x = self.conv_pw(x)
if self.bn is not None:
x = self.bn(x)
x = self.bn(x)
return x
SeparableConvBnAct = SeparableConvNormAct
class SeparableConv2d(nn.Module):
""" Separable Conv
"""

@ -35,11 +35,10 @@ class SplitAttn(nn.Module):
"""
def __init__(self, in_channels, out_channels=None, kernel_size=3, stride=1, padding=None,
dilation=1, groups=1, bias=False, radix=2, rd_ratio=0.25, rd_channels=None, rd_divisor=8,
act_layer=nn.ReLU, norm_layer=None, drop_block=None, **kwargs):
act_layer=nn.ReLU, norm_layer=None, drop_layer=None, **kwargs):
super(SplitAttn, self).__init__()
out_channels = out_channels or in_channels
self.radix = radix
self.drop_block = drop_block
mid_chs = out_channels * radix
if rd_channels is None:
attn_chs = make_divisible(in_channels * radix * rd_ratio, min_value=32, divisor=rd_divisor)
@ -51,6 +50,7 @@ class SplitAttn(nn.Module):
in_channels, mid_chs, kernel_size, stride, padding, dilation,
groups=groups * radix, bias=bias, **kwargs)
self.bn0 = norm_layer(mid_chs) if norm_layer else nn.Identity()
self.drop = drop_layer() if drop_layer is not None else nn.Identity()
self.act0 = act_layer(inplace=True)
self.fc1 = nn.Conv2d(out_channels, attn_chs, 1, groups=groups)
self.bn1 = norm_layer(attn_chs) if norm_layer else nn.Identity()
@ -61,8 +61,7 @@ class SplitAttn(nn.Module):
def forward(self, x):
x = self.conv(x)
x = self.bn0(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.drop(x)
x = self.act0(x)
B, RC, H, W = x.shape

@ -20,7 +20,7 @@ from .efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficien
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
from .features import FeatureInfo, FeatureHooks
from .helpers import build_model_with_cfg, default_cfg_for_features
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, hard_sigmoid
from .layers import SelectAdaptivePool2d, Linear, create_conv2d, get_act_fn, get_norm_act_layer
from .registry import register_model
__all__ = ['MobileNetV3', 'MobileNetV3Features']
@ -95,6 +95,7 @@ class MobileNetV3(nn.Module):
super(MobileNetV3, self).__init__()
act_layer = act_layer or nn.ReLU
norm_layer = norm_layer or nn.BatchNorm2d
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
se_layer = se_layer or SqueezeExcite
self.num_classes = num_classes
self.num_features = num_features
@ -103,8 +104,7 @@ class MobileNetV3(nn.Module):
# Stem
stem_size = round_chs_fn(stem_size)
self.conv_stem = create_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type)
self.bn1 = norm_layer(stem_size)
self.act1 = act_layer(inplace=True)
self.bn1 = norm_act_layer(stem_size, inplace=True)
# Middle stages (IR/ER/DS Blocks)
builder = EfficientNetBuilder(
@ -125,7 +125,7 @@ class MobileNetV3(nn.Module):
efficientnet_init_weights(self)
def as_sequential(self):
layers = [self.conv_stem, self.bn1, self.act1]
layers = [self.conv_stem, self.bn1]
layers.extend(self.blocks)
layers.extend([self.global_pool, self.conv_head, self.act2])
layers.extend([nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier])
@ -144,7 +144,6 @@ class MobileNetV3(nn.Module):
def forward_features(self, x):
x = self.conv_stem(x)
x = self.bn1(x)
x = self.act1(x)
x = self.blocks(x)
x = self.global_pool(x)
x = self.conv_head(x)

@ -9,7 +9,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model
__all__ = ['NASNetALarge']
@ -420,7 +420,7 @@ class NASNetALarge(nn.Module):
channels = self.num_features // 24
# 24 is default value for the architecture
self.conv0 = ConvBnAct(
self.conv0 = ConvNormAct(
in_channels=in_chans, out_channels=self.stem_size, kernel_size=3, padding=0, stride=2,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

@ -13,7 +13,7 @@ import torch.nn as nn
import torch.nn.functional as F
from .helpers import build_model_with_cfg
from .layers import ConvBnAct, create_conv2d, create_pool2d, create_classifier
from .layers import ConvNormAct, create_conv2d, create_pool2d, create_classifier
from .registry import register_model
__all__ = ['PNASNet5Large']
@ -243,7 +243,7 @@ class PNASNet5Large(nn.Module):
self.num_features = 4320
assert output_stride == 32
self.conv_0 = ConvBnAct(
self.conv_0 = ConvNormAct(
in_chans, 96, kernel_size=3, stride=2, padding=0,
norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.1), apply_act=False)

@ -15,45 +15,76 @@ Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
import torch.nn as nn
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union, Callable
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, AvgPool2dSame, ConvBnAct, SEModule, DropPath
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct
from .registry import register_model
def _mcfg(**kwargs):
cfg = dict(se_ratio=0., bottle_ratio=1., stem_width=32)
cfg.update(**kwargs)
return cfg
@dataclass
class RegNetCfg:
depth: int = 21
w0: int = 80
wa: float = 42.63
wm: float = 2.66
group_size: int = 24
bottle_ratio: float = 1.
se_ratio: float = 0.
stem_width: int = 32
downsample: Optional[str] = 'conv1x1'
linear_out: bool = False
act_layer: Union[str, Callable] = 'relu'
norm_layer: Union[str, Callable] = 'batchnorm'
# Model FLOPS = three trailing digits * 10^8
model_cfgs = dict(
regnetx_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13),
regnetx_004=_mcfg(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22),
regnetx_006=_mcfg(w0=48, wa=36.97, wm=2.24, group_w=24, depth=16),
regnetx_008=_mcfg(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16),
regnetx_016=_mcfg(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18),
regnetx_032=_mcfg(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25),
regnetx_040=_mcfg(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23),
regnetx_064=_mcfg(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17),
regnetx_080=_mcfg(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23),
regnetx_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19),
regnetx_160=_mcfg(w0=216, wa=55.59, wm=2.1, group_w=128, depth=22),
regnetx_320=_mcfg(w0=320, wa=69.86, wm=2.0, group_w=168, depth=23),
regnety_002=_mcfg(w0=24, wa=36.44, wm=2.49, group_w=8, depth=13, se_ratio=0.25),
regnety_004=_mcfg(w0=48, wa=27.89, wm=2.09, group_w=8, depth=16, se_ratio=0.25),
regnety_006=_mcfg(w0=48, wa=32.54, wm=2.32, group_w=16, depth=15, se_ratio=0.25),
regnety_008=_mcfg(w0=56, wa=38.84, wm=2.4, group_w=16, depth=14, se_ratio=0.25),
regnety_016=_mcfg(w0=48, wa=20.71, wm=2.65, group_w=24, depth=27, se_ratio=0.25),
regnety_032=_mcfg(w0=80, wa=42.63, wm=2.66, group_w=24, depth=21, se_ratio=0.25),
regnety_040=_mcfg(w0=96, wa=31.41, wm=2.24, group_w=64, depth=22, se_ratio=0.25),
regnety_064=_mcfg(w0=112, wa=33.22, wm=2.27, group_w=72, depth=25, se_ratio=0.25),
regnety_080=_mcfg(w0=192, wa=76.82, wm=2.19, group_w=56, depth=17, se_ratio=0.25),
regnety_120=_mcfg(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, se_ratio=0.25),
regnety_160=_mcfg(w0=200, wa=106.23, wm=2.48, group_w=112, depth=18, se_ratio=0.25),
regnety_320=_mcfg(w0=232, wa=115.89, wm=2.53, group_w=232, depth=20, se_ratio=0.25),
# RegNet-X
regnetx_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13),
regnetx_004=RegNetCfg(w0=24, wa=24.48, wm=2.54, group_size=16, depth=22),
regnetx_006=RegNetCfg(w0=48, wa=36.97, wm=2.24, group_size=24, depth=16),
regnetx_008=RegNetCfg(w0=56, wa=35.73, wm=2.28, group_size=16, depth=16),
regnetx_016=RegNetCfg(w0=80, wa=34.01, wm=2.25, group_size=24, depth=18),
regnetx_032=RegNetCfg(w0=88, wa=26.31, wm=2.25, group_size=48, depth=25),
regnetx_040=RegNetCfg(w0=96, wa=38.65, wm=2.43, group_size=40, depth=23),
regnetx_064=RegNetCfg(w0=184, wa=60.83, wm=2.07, group_size=56, depth=17),
regnetx_080=RegNetCfg(w0=80, wa=49.56, wm=2.88, group_size=120, depth=23),
regnetx_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19),
regnetx_160=RegNetCfg(w0=216, wa=55.59, wm=2.1, group_size=128, depth=22),
regnetx_320=RegNetCfg(w0=320, wa=69.86, wm=2.0, group_size=168, depth=23),
# RegNet-Y
regnety_002=RegNetCfg(w0=24, wa=36.44, wm=2.49, group_size=8, depth=13, se_ratio=0.25),
regnety_004=RegNetCfg(w0=48, wa=27.89, wm=2.09, group_size=8, depth=16, se_ratio=0.25),
regnety_006=RegNetCfg(w0=48, wa=32.54, wm=2.32, group_size=16, depth=15, se_ratio=0.25),
regnety_008=RegNetCfg(w0=56, wa=38.84, wm=2.4, group_size=16, depth=14, se_ratio=0.25),
regnety_016=RegNetCfg(w0=48, wa=20.71, wm=2.65, group_size=24, depth=27, se_ratio=0.25),
regnety_032=RegNetCfg(w0=80, wa=42.63, wm=2.66, group_size=24, depth=21, se_ratio=0.25),
regnety_040=RegNetCfg(w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25),
regnety_064=RegNetCfg(w0=112, wa=33.22, wm=2.27, group_size=72, depth=25, se_ratio=0.25),
regnety_080=RegNetCfg(w0=192, wa=76.82, wm=2.19, group_size=56, depth=17, se_ratio=0.25),
regnety_120=RegNetCfg(w0=168, wa=73.36, wm=2.37, group_size=112, depth=19, se_ratio=0.25),
regnety_160=RegNetCfg(w0=200, wa=106.23, wm=2.48, group_size=112, depth=18, se_ratio=0.25),
regnety_320=RegNetCfg(w0=232, wa=115.89, wm=2.53, group_size=232, depth=20, se_ratio=0.25),
# Experimental
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
),
)
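Since the configs are now a dataclass rather than the old _mcfg dict, a new or tweaked variant is a one-liner. A hypothetical example (the field values below are illustrative, not a published model, and the timm.models.regnet import path is assumed from the file shown here):

from functools import partial
from timm.models.regnet import RegNetCfg      # assumed module path
from timm.models.layers import GroupNormAct

# Hypothetical GroupNorm + SiLU variant; every field maps 1:1 onto the dataclass above.
my_cfg = RegNetCfg(
    depth=18, w0=64, wa=30.0, wm=2.3, group_size=16, se_ratio=0.25,
    downsample='avg', act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16))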
@ -80,6 +111,7 @@ default_cfgs = dict(
regnetx_120=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_120-65d5521e.pth'),
regnetx_160=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_160-c98c4112.pth'),
regnetx_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnetx_320-8ea38b93.pth'),
regnety_002=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_002-e68ca334.pth'),
regnety_004=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_004-0db870e6.pth'),
regnety_006=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_006-c67e57ec.pth'),
@ -96,6 +128,11 @@ default_cfgs = dict(
url='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth', # from Facebook DeiT GitHub repository
crop_pct=1.0, test_input_size=(3, 288, 288)),
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
)
@ -125,6 +162,40 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
return widths, num_stages, max_stage, widths_cont
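For reference, generate_regnet implements the quantized linear width parameterization from the RegNet paper (https://arxiv.org/abs/2003.13678); a rough standalone sketch, which may differ from the in-tree version in rounding details:

import numpy as np

def regnet_widths(wa, w0, wm, depth, q=8):
    widths_cont = np.arange(depth) * wa + w0                 # u_j = w0 + wa * j
    exps = np.round(np.log(widths_cont / w0) / np.log(wm))   # quantize the log-space schedule
    widths = np.round(w0 * np.power(wm, exps) / q) * q       # snap each width to a multiple of q
    return widths.astype(int), len(np.unique(widths))        # per-block widths, stage count

print(regnet_widths(wa=36.44, w0=24, wm=2.49, depth=13))     # the regnetx_002 parameters above -> widths 24/56/152/368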
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvNormAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn.Identity()
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)])
def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
if not downsample_type:
return None # no shortcut, no downsample
elif downsample_type == 'avg':
return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
else:
return downsample_conv(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
else:
return nn.Identity() # identity shortcut (no downsample)
class Bottleneck(nn.Module):
""" RegNet Bottleneck
@ -132,97 +203,70 @@ class Bottleneck(nn.Module):
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
"""
def __init__(self, in_chs, out_chs, stride=1, dilation=1, bottleneck_ratio=1, group_width=1, se_ratio=0.25,
downsample=None, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, aa_layer=None,
drop_block=None, drop_path=None):
def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
super(Bottleneck, self).__init__()
bottleneck_chs = int(round(out_chs * bottleneck_ratio))
groups = bottleneck_chs // group_width
cargs = dict(act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer, drop_block=drop_block)
self.conv1 = ConvBnAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
self.conv2 = ConvBnAct(
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation,
groups=groups, **cargs)
act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
groups = bottleneck_chs // group_size
cargs = dict(act_layer=act_layer, norm_layer=norm_layer)
self.conv1 = ConvNormAct(in_chs, bottleneck_chs, kernel_size=1, **cargs)
self.conv2 = ConvNormAct(
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0],
groups=groups, drop_layer=drop_block, **cargs)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels)
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
else:
self.se = None
cargs['act_layer'] = None
self.conv3 = ConvBnAct(bottleneck_chs, out_chs, kernel_size=1, **cargs)
self.act3 = act_layer(inplace=True)
self.downsample = downsample
self.drop_path = drop_path
def zero_init_last_bn(self):
self.se = nn.Identity()
self.conv3 = ConvNormAct(bottleneck_chs, out_chs, kernel_size=1, apply_act=False, **cargs)
self.act3 = nn.Identity() if linear_out else act_layer()
self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, norm_layer=norm_layer)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
nn.init.zeros_(self.conv3.bn.weight)
def forward(self, x):
shortcut = x
x = self.conv1(x)
x = self.conv2(x)
if self.se is not None:
x = self.se(x)
x = self.se(x)
x = self.conv3(x)
if self.drop_path is not None:
x = self.drop_path(x)
if self.downsample is not None:
shortcut = self.downsample(shortcut)
x += shortcut
# NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = x + self.drop_path(self.downsample(shortcut))
x = self.act3(x)
return x
def downsample_conv(
in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvBnAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, act_layer=None)
def downsample_avg(
in_chs, out_chs, kernel_size, stride=1, dilation=1, norm_layer=None):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
pool = nn.Identity()
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvBnAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, act_layer=None)])
class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(self, in_chs, out_chs, stride, dilation, depth, bottle_ratio, group_width,
block_fn=Bottleneck, se_ratio=0., drop_path_rates=None, drop_block=None):
def __init__(
self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck,
se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_path_rates=None, drop_block=None):
super(RegStage, self).__init__()
block_kwargs = {} # FIXME setup to pass various aa, norm, act layer common args
block_kwargs = dict(
bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample,
linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block)
first_dilation = 1 if dilation in (1, 2) else 2
for i in range(depth):
block_stride = stride if i == 0 else 1
block_in_chs = in_chs if i == 0 else out_chs
block_dilation = first_dilation if i == 0 else dilation
if drop_path_rates is not None and drop_path_rates[i] > 0.:
drop_path = DropPath(drop_path_rates[i])
else:
drop_path = None
if (block_in_chs != out_chs) or (block_stride != 1):
proj_block = downsample_conv(block_in_chs, out_chs, 1, block_stride, block_dilation)
else:
proj_block = None
block_dilation = (first_dilation, dilation)
dpr = drop_path_rates[i] if drop_path_rates is not None else 0.
name = "b{}".format(i + 1)
self.add_module(
name, block_fn(
block_in_chs, out_chs, block_stride, block_dilation, bottle_ratio, group_width, se_ratio,
downsample=proj_block, drop_block=drop_block, drop_path=drop_path, **block_kwargs)
block_in_chs, out_chs, stride=block_stride, dilation=block_dilation,
drop_path_rate=dpr, **block_kwargs)
)
first_dilation = dilation
def forward(self, x):
for block in self.children():
@ -231,33 +275,34 @@ class RegStage(nn.Module):
class RegNet(nn.Module):
"""RegNet model.
"""RegNet-X, Y, and Z Models
Paper: https://arxiv.org/abs/2003.13678
Original Impl: https://github.com/facebookresearch/pycls/blob/master/pycls/models/regnet.py
"""
def __init__(self, cfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg', drop_rate=0.,
drop_path_rate=0., zero_init_last_bn=True):
def __init__(
self, cfg: RegNetCfg, in_chans=3, num_classes=1000, output_stride=32, global_pool='avg',
drop_rate=0., drop_path_rate=0., zero_init_last=True):
super().__init__()
# TODO add drop block, drop path, anti-aliasing, custom bn/act args
self.num_classes = num_classes
self.drop_rate = drop_rate
assert output_stride in (8, 16, 32)
# Construct the stem
stem_width = cfg['stem_width']
self.stem = ConvBnAct(in_chans, stem_width, 3, stride=2)
stem_width = cfg.stem_width
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
# Construct the stages
prev_width = stem_width
curr_stride = 2
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
se_ratio = cfg['se_ratio']
for i, stage_args in enumerate(stage_params):
stage_name = "s{}".format(i + 1)
self.add_module(stage_name, RegStage(prev_width, **stage_args, se_ratio=se_ratio))
self.add_module(stage_name, RegStage(
in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args))
prev_width = stage_args['out_chs']
curr_stride *= stage_args['stride']
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
@ -267,31 +312,18 @@ class RegNet(nn.Module):
self.head = ClassifierHead(
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0.0, std=0.01)
nn.init.zeros_(m.bias)
if zero_init_last_bn:
for m in self.modules():
if hasattr(m, 'zero_init_last_bn'):
m.zero_init_last_bn()
def _get_stage_params(self, cfg, default_stride=2, output_stride=32, drop_path_rate=0.):
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
# Generate RegNet ws per block
w_a, w_0, w_m, d = cfg['wa'], cfg['w0'], cfg['wm'], cfg['depth']
widths, num_stages, _, _ = generate_regnet(w_a, w_0, w_m, d)
widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth)
# Convert to per stage format
stage_widths, stage_depths = np.unique(widths, return_counts=True)
# Use the same group width, bottleneck mult and stride for each stage
stage_groups = [cfg['group_w'] for _ in range(num_stages)]
stage_bottle_ratios = [cfg['bottle_ratio'] for _ in range(num_stages)]
stage_groups = [cfg.group_size for _ in range(num_stages)]
stage_bottle_ratios = [cfg.bottle_ratio for _ in range(num_stages)]
stage_strides = []
stage_dilations = []
net_stride = 2
@ -305,11 +337,11 @@ class RegNet(nn.Module):
net_stride *= stride
stage_strides.append(stride)
stage_dilations.append(dilation)
stage_dpr = np.split(np.linspace(0, drop_path_rate, d), np.cumsum(stage_depths[:-1]))
stage_dpr = np.split(np.linspace(0, drop_path_rate, cfg.depth), np.cumsum(stage_depths[:-1]))
# Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_width', 'drop_path_rates']
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
stage_params = [
dict(zip(param_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
@ -333,6 +365,19 @@ class RegNet(nn.Module):
return x
def _init_weights(module, name='', zero_init_last=False):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias)
elif hasattr(module, 'zero_init_last'):
module.zero_init_last()
def _filter_fn(state_dict):
""" convert patch embedding weight from manual patchify + linear proj to conv"""
if 'model' in state_dict:
@ -492,3 +537,27 @@ def regnety_160(pretrained=False, **kwargs):
def regnety_320(pretrained=False, **kwargs):
"""RegNetY-32GF"""
return _create_regnet('regnety_320', pretrained, **kwargs)
@register_model
def regnety_040s_gn(pretrained=False, **kwargs):
"""RegNetY-4.0GF w/ GroupNorm """
return _create_regnet('regnety_040s_gn', pretrained, **kwargs)
@register_model
def regnetz_005(pretrained=False, **kwargs):
"""RegNetZ-500MF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear whether it is equivalent to the paper's model, as it is not detailed in the paper.
"""
return _create_regnet('regnetz_005', pretrained, **kwargs)
@register_model
def regnetz_040(pretrained=False, **kwargs):
"""RegNetZ-4.0GF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear whether it is equivalent to the paper's model, as it is not detailed in the paper.
"""
return _create_regnet('regnetz_040', pretrained, **kwargs)
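Both RegNetZ variants are registered like any other timm model, so they can be instantiated by name; a quick sketch (no pretrained weights yet, since the cfg urls above are empty):

import timm
import torch

model = timm.create_model('regnetz_005', pretrained=False)
logits = model(torch.randn(1, 3, 224, 224))   # (1, 1000)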

@ -75,7 +75,6 @@ class ResNestBottleneck(nn.Module):
else:
avd_stride = 0
self.radix = radix
self.drop_block = drop_block
self.conv1 = nn.Conv2d(inplanes, group_width, kernel_size=1, bias=False)
self.bn1 = norm_layer(group_width)
@ -85,14 +84,16 @@ class ResNestBottleneck(nn.Module):
if self.radix >= 1:
self.conv2 = SplitAttn(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_block=drop_block)
dilation=first_dilation, groups=cardinality, radix=radix, norm_layer=norm_layer, drop_layer=drop_block)
self.bn2 = nn.Identity()
self.drop_block = nn.Identity()
self.act2 = nn.Identity()
else:
self.conv2 = nn.Conv2d(
group_width, group_width, kernel_size=3, stride=stride, padding=first_dilation,
dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(group_width)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act2 = act_layer(inplace=True)
self.avd_last = nn.AvgPool2d(3, avd_stride, padding=1) if avd_stride > 0 and not avd_first else None
@ -109,8 +110,6 @@ class ResNestBottleneck(nn.Module):
out = self.conv1(x)
out = self.bn1(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.act1(out)
if self.avd_first is not None:
@ -118,8 +117,7 @@ class ResNestBottleneck(nn.Module):
out = self.conv2(out)
out = self.bn2(out)
if self.drop_block is not None:
out = self.drop_block(out)
out = self.drop_block(out)
out = self.act2(out)
if self.avd_last is not None:
@ -127,8 +125,6 @@ class ResNestBottleneck(nn.Module):
out = self.conv3(out)
out = self.bn3(out)
if self.drop_block is not None:
out = self.drop_block(out)
if self.downsample is not None:
shortcut = self.downsample(x)

@ -307,8 +307,9 @@ class BasicBlock(nn.Module):
inplanes, first_planes, kernel_size=3, stride=1 if use_aa else stride, padding=first_dilation,
dilation=first_dilation, bias=False)
self.bn1 = norm_layer(first_planes)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act1 = act_layer(inplace=True)
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else nn.Identity()
self.conv2 = nn.Conv2d(
first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
@ -320,7 +321,6 @@ class BasicBlock(nn.Module):
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
@ -331,16 +331,12 @@ class BasicBlock(nn.Module):
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.drop_block(x)
x = self.act1(x)
if self.aa is not None:
x = self.aa(x)
x = self.aa(x)
x = self.conv2(x)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None:
x = self.se(x)
@ -378,8 +374,9 @@ class Bottleneck(nn.Module):
first_planes, width, kernel_size=3, stride=1 if use_aa else stride,
padding=first_dilation, dilation=first_dilation, groups=cardinality, bias=False)
self.bn2 = norm_layer(width)
self.drop_block = drop_block() if drop_block is not None else nn.Identity()
self.act2 = act_layer(inplace=True)
self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
self.aa = aa_layer(channels=width, stride=stride) if use_aa else nn.Identity()
self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
self.bn3 = norm_layer(outplanes)
@ -390,7 +387,6 @@ class Bottleneck(nn.Module):
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
@ -401,22 +397,16 @@ class Bottleneck(nn.Module):
x = self.conv1(x)
x = self.bn1(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.act1(x)
x = self.conv2(x)
x = self.bn2(x)
if self.drop_block is not None:
x = self.drop_block(x)
x = self.drop_block(x)
x = self.act2(x)
if self.aa is not None:
x = self.aa(x)
x = self.aa(x)
x = self.conv3(x)
x = self.bn3(x)
if self.drop_block is not None:
x = self.drop_block(x)
if self.se is not None:
x = self.se(x)
@ -463,11 +453,11 @@ def downsample_avg(
])
def drop_blocks(drop_block_rate=0.):
def drop_blocks(drop_prob=0.):
return [
None, None,
DropBlock2d(drop_block_rate, 5, 0.25) if drop_block_rate else None,
DropBlock2d(drop_block_rate, 3, 1.00) if drop_block_rate else None]
partial(DropBlock2d, drop_prob=drop_prob, block_size=5, gamma_scale=0.25) if drop_prob else None,
partial(DropBlock2d, drop_prob=drop_prob, block_size=3, gamma_scale=1.00) if drop_prob else None]
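This change mirrors the drop_block -> drop_layer convention in the layers above: each stage now receives a DropBlock2d factory (a partial) instead of a shared instance, and each block builds its own copy. A small sketch of the pattern, assuming DropBlock2d keeps the keyword arguments used in the partials above:

from functools import partial

import torch
import torch.nn as nn
from timm.models.layers import DropBlock2d

drop_block = partial(DropBlock2d, drop_prob=0.1, block_size=5, gamma_scale=0.25)   # per-stage factory
layer = drop_block() if drop_block is not None else nn.Identity()                  # per-block instance, as in BasicBlock/Bottleneck above
y = layer(torch.randn(2, 64, 28, 28))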
def make_blocks(

@ -17,7 +17,7 @@ from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible, SEModule
from .layers import ClassifierHead, create_act_layer, ConvNormAct, DropPath, make_divisible, SEModule
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -63,19 +63,19 @@ class LinearBottleneck(nn.Module):
if exp_ratio != 1.:
dw_chs = make_divisible(round(in_chs * exp_ratio), divisor=ch_div)
self.conv_exp = ConvBnAct(in_chs, dw_chs, act_layer=act_layer)
self.conv_exp = ConvNormAct(in_chs, dw_chs, act_layer=act_layer)
else:
dw_chs = in_chs
self.conv_exp = None
self.conv_dw = ConvBnAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
self.conv_dw = ConvNormAct(dw_chs, dw_chs, 3, stride=stride, groups=dw_chs, apply_act=False)
if se_ratio > 0:
self.se = SEWithNorm(dw_chs, rd_channels=make_divisible(int(dw_chs * se_ratio), ch_div))
else:
self.se = None
self.act_dw = create_act_layer(dw_act_layer)
self.conv_pwl = ConvBnAct(dw_chs, out_chs, 1, apply_act=False)
self.conv_pwl = ConvNormAct(dw_chs, out_chs, 1, apply_act=False)
self.drop_path = drop_path
def feat_channels(self, exp=False):
@ -138,7 +138,7 @@ def _build_blocks(
feat_chs += [features[-1].feat_channels()]
pen_chs = make_divisible(1280 * width_mult, divisor=ch_div)
feature_info += [dict(num_chs=feat_chs[-1], reduction=curr_stride, module=f'features.{len(features) - 1}')]
features.append(ConvBnAct(prev_chs, pen_chs, act_layer=act_layer))
features.append(ConvNormAct(prev_chs, pen_chs, act_layer=act_layer))
return features, feature_info
@ -153,7 +153,7 @@ class ReXNetV1(nn.Module):
assert output_stride == 32 # FIXME support dilation
stem_base_chs = 32 / width_mult if width_mult < 1.0 else 32
stem_chs = make_divisible(round(stem_base_chs * width_mult), divisor=ch_div)
self.stem = ConvBnAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
self.stem = ConvNormAct(in_chans, stem_chs, 3, stride=2, act_layer=act_layer)
block_cfg = _block_cfg(width_mult, depth_mult, initial_chs, final_chs, se_ratio, ch_div)
features, self.feature_info = _build_blocks(

@ -14,7 +14,7 @@ from torch import nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import SelectiveKernel, ConvBnAct, create_attn
from .layers import SelectiveKernel, ConvNormAct, ConvNormActAa, create_attn
from .registry import register_model
from .resnet import ResNet
@ -52,7 +52,7 @@ class SelectiveKernelBasic(nn.Module):
super(SelectiveKernelBasic, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
assert base_width == 64, 'BasicBlock does not support changing base width'
first_planes = planes // reduce_first
@ -60,16 +60,13 @@ class SelectiveKernelBasic(nn.Module):
first_dilation = first_dilation or dilation
self.conv1 = SelectiveKernel(
inplanes, first_planes, stride=stride, dilation=first_dilation, **conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv2 = ConvBnAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, **conv_kwargs)
inplanes, first_planes, stride=stride, dilation=first_dilation,
aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
self.conv2 = ConvNormAct(
first_planes, outplanes, kernel_size=3, dilation=dilation, apply_act=False, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):
@ -100,24 +97,20 @@ class SelectiveKernelBottleneck(nn.Module):
super(SelectiveKernelBottleneck, self).__init__()
sk_kwargs = sk_kwargs or {}
conv_kwargs = dict(drop_block=drop_block, act_layer=act_layer, norm_layer=norm_layer, aa_layer=aa_layer)
conv_kwargs = dict(act_layer=act_layer, norm_layer=norm_layer)
width = int(math.floor(planes * (base_width / 64)) * cardinality)
first_planes = width // reduce_first
outplanes = planes * self.expansion
first_dilation = first_dilation or dilation
self.conv1 = ConvBnAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv1 = ConvNormAct(inplanes, first_planes, kernel_size=1, **conv_kwargs)
self.conv2 = SelectiveKernel(
first_planes, width, stride=stride, dilation=first_dilation, groups=cardinality,
**conv_kwargs, **sk_kwargs)
conv_kwargs['act_layer'] = None
self.conv3 = ConvBnAct(width, outplanes, kernel_size=1, **conv_kwargs)
aa_layer=aa_layer, drop_layer=drop_block, **conv_kwargs, **sk_kwargs)
self.conv3 = ConvNormAct(width, outplanes, kernel_size=1, apply_act=False, **conv_kwargs)
self.se = create_attn(attn_layer, outplanes)
self.act = act_layer(inplace=True)
self.downsample = downsample
self.stride = stride
self.dilation = dilation
self.drop_block = drop_block
self.drop_path = drop_path
def zero_init_last_bn(self):

@ -20,8 +20,8 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .registry import register_model
from .helpers import build_model_with_cfg
from .layers import ConvBnAct, SeparableConvBnAct, BatchNormAct2d, ClassifierHead, DropPath,\
create_attn, create_norm_act, get_norm_act_layer
from .layers import ConvNormAct, SeparableConvNormAct, BatchNormAct2d, ClassifierHead, DropPath,\
create_attn, create_norm_act_layer, get_norm_act_layer
# model cfgs adapted from https://github.com/youngwanLEE/vovnet-detectron2 &
@ -189,23 +189,23 @@ class OsaBlock(nn.Module):
next_in_chs = in_chs
if self.depthwise and next_in_chs != mid_chs:
assert not residual
self.conv_reduction = ConvBnAct(next_in_chs, mid_chs, 1, **conv_kwargs)
self.conv_reduction = ConvNormAct(next_in_chs, mid_chs, 1, **conv_kwargs)
else:
self.conv_reduction = None
mid_convs = []
for i in range(layer_per_block):
if self.depthwise:
conv = SeparableConvBnAct(mid_chs, mid_chs, **conv_kwargs)
conv = SeparableConvNormAct(mid_chs, mid_chs, **conv_kwargs)
else:
conv = ConvBnAct(next_in_chs, mid_chs, 3, **conv_kwargs)
conv = ConvNormAct(next_in_chs, mid_chs, 3, **conv_kwargs)
next_in_chs = mid_chs
mid_convs.append(conv)
self.conv_mid = SequentialAppendList(*mid_convs)
# feature aggregation
next_in_chs = in_chs + layer_per_block * mid_chs
self.conv_concat = ConvBnAct(next_in_chs, out_chs, **conv_kwargs)
self.conv_concat = ConvNormAct(next_in_chs, out_chs, **conv_kwargs)
if attn:
self.attn = create_attn(attn, out_chs)
@ -283,9 +283,9 @@ class VovNet(nn.Module):
# Stem module
last_stem_stride = stem_stride // 2
conv_type = SeparableConvBnAct if cfg["depthwise"] else ConvBnAct
conv_type = SeparableConvNormAct if cfg["depthwise"] else ConvNormAct
self.stem = nn.Sequential(*[
ConvBnAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
ConvNormAct(in_chans, stem_chs[0], 3, stride=2, **conv_kwargs),
conv_type(stem_chs[0], stem_chs[1], 3, stride=1, **conv_kwargs),
conv_type(stem_chs[1], stem_chs[2], 3, stride=last_stem_stride, **conv_kwargs),
])
@ -395,12 +395,12 @@ def eca_vovnet39b(pretrained=False, **kwargs):
@register_model
def ese_vovnet39b_evos(pretrained=False, **kwargs):
def norm_act_fn(num_features, **nkwargs):
return create_norm_act('evonorms0', num_features, jit=False, **nkwargs)
return create_norm_act_layer('evonorms0', num_features, jit=False, **nkwargs)
return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)
@register_model
def ese_vovnet99b_iabn(pretrained=False, **kwargs):
norm_layer = get_norm_act_layer('iabn')
norm_layer = get_norm_act_layer('iabn', act_layer='leaky_relu')
return _create_vovnet(
'ese_vovnet99b_iabn', pretrained=pretrained, norm_layer=norm_layer, act_layer=nn.LeakyReLU, **kwargs)

@ -12,7 +12,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, ConvBnAct, create_conv2d
from .layers import ClassifierHead, ConvNormAct, create_conv2d, get_norm_act_layer
from .layers.helpers import to_3tuple
from .registry import register_model
@ -37,12 +37,14 @@ default_cfgs = dict(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
xception71=_cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
xception41p=_cfg(url=''),
)
class SeparableConv2d(nn.Module):
def __init__(
self, inplanes, planes, kernel_size=3, stride=1, dilation=1, padding='',
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
super(SeparableConv2d, self).__init__()
self.kernel_size = kernel_size
@ -50,31 +52,48 @@ class SeparableConv2d(nn.Module):
# depthwise convolution
self.conv_dw = create_conv2d(
inplanes, inplanes, kernel_size, stride=stride,
in_chs, in_chs, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True)
self.bn_dw = norm_layer(inplanes)
if act_layer is not None:
self.act_dw = act_layer(inplace=True)
else:
self.act_dw = None
self.bn_dw = norm_layer(in_chs)
self.act_dw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
# pointwise convolution
self.conv_pw = create_conv2d(inplanes, planes, kernel_size=1)
self.bn_pw = norm_layer(planes)
if act_layer is not None:
self.act_pw = act_layer(inplace=True)
else:
self.act_pw = None
self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1)
self.bn_pw = norm_layer(out_chs)
self.act_pw = act_layer(inplace=True) if act_layer is not None else nn.Identity()
def forward(self, x):
x = self.conv_dw(x)
x = self.bn_dw(x)
if self.act_dw is not None:
x = self.act_dw(x)
x = self.act_dw(x)
x = self.conv_pw(x)
x = self.bn_pw(x)
if self.act_pw is not None:
x = self.act_pw(x)
x = self.act_pw(x)
return x
class PreSeparableConv2d(nn.Module):
def __init__(
self, in_chs, out_chs, kernel_size=3, stride=1, dilation=1, padding='',
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, first_act=True):
super(PreSeparableConv2d, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer=act_layer)
self.kernel_size = kernel_size
self.dilation = dilation
self.norm = norm_act_layer(in_chs, inplace=True) if first_act else nn.Identity()
# depthwise convolution
self.conv_dw = create_conv2d(
in_chs, in_chs, kernel_size, stride=stride,
padding=padding, dilation=dilation, depthwise=True)
# pointwise convolution
self.conv_pw = create_conv2d(in_chs, out_chs, kernel_size=1)
def forward(self, x):
x = self.norm(x)
x = self.conv_dw(x)
x = self.conv_pw(x)
return x
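An illustrative, timm-free sketch of the ordering difference: the standard separable block runs conv -> norm -> act for both convs, while the Pre variant applies a single norm + act up front and leaves both convs bare, so blocks can be chained pre-activation style (first_act=False skips the leading norm + act, as used for the first conv of each PreXceptionModule below):

import torch
import torch.nn as nn

class TinyPreSeparable(nn.Module):
    """Minimal stand-in for the pre-act separable conv pattern above (not the timm class)."""
    def __init__(self, in_chs, out_chs, first_act=True):
        super().__init__()
        self.norm = nn.Sequential(nn.BatchNorm2d(in_chs), nn.ReLU(inplace=True)) if first_act else nn.Identity()
        self.conv_dw = nn.Conv2d(in_chs, in_chs, 3, padding=1, groups=in_chs, bias=False)   # depthwise
        self.conv_pw = nn.Conv2d(in_chs, out_chs, 1, bias=False)                            # pointwise

    def forward(self, x):
        return self.conv_pw(self.conv_dw(self.norm(x)))

y = TinyPreSeparable(32, 64)(torch.randn(1, 32, 16, 16))   # (1, 64, 16, 16)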
@ -88,8 +107,8 @@ class XceptionModule(nn.Module):
self.out_channels = out_chs[-1]
self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = ConvBnAct(
in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, act_layer=None)
self.shortcut = ConvNormAct(
in_chs, self.out_channels, 1, stride=stride, norm_layer=norm_layer, apply_act=False)
else:
self.shortcut = None
@ -97,7 +116,7 @@ class XceptionModule(nn.Module):
self.stack = nn.Sequential()
for i in range(3):
if start_with_relu:
self.stack.add_module(f'act{i + 1}', nn.ReLU(inplace=i > 0))
self.stack.add_module(f'act{i + 1}', act_layer(inplace=i > 0))
self.stack.add_module(f'conv{i + 1}', SeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=separable_act_layer, norm_layer=norm_layer))
@ -113,11 +132,42 @@ class XceptionModule(nn.Module):
return x
class PreXceptionModule(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, pad_type='',
no_skip=False, act_layer=nn.ReLU, norm_layer=None):
super(PreXceptionModule, self).__init__()
out_chs = to_3tuple(out_chs)
self.in_channels = in_chs
self.out_channels = out_chs[-1]
self.no_skip = no_skip
if not no_skip and (self.out_channels != self.in_channels or stride != 1):
self.shortcut = create_conv2d(in_chs, self.out_channels, 1, stride=stride)
else:
self.shortcut = nn.Identity()
self.norm = get_norm_act_layer(norm_layer, act_layer=act_layer)(in_chs, inplace=True)
self.stack = nn.Sequential()
for i in range(3):
self.stack.add_module(f'conv{i + 1}', PreSeparableConv2d(
in_chs, out_chs[i], 3, stride=stride if i == 2 else 1, dilation=dilation, padding=pad_type,
act_layer=act_layer, norm_layer=norm_layer, first_act=i > 0))
in_chs = out_chs[i]
def forward(self, x):
x = self.norm(x)
skip = x
x = self.stack(x)
if not self.no_skip:
x = x + self.shortcut(skip)
return x
class XceptionAligned(nn.Module):
"""Modified Aligned Xception
"""
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32,
def __init__(self, block_cfg, num_classes=1000, in_chans=3, output_stride=32, preact=False,
act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0., global_pool='avg'):
super(XceptionAligned, self).__init__()
self.num_classes = num_classes
@ -126,31 +176,33 @@ class XceptionAligned(nn.Module):
layer_args = dict(act_layer=act_layer, norm_layer=norm_layer)
self.stem = nn.Sequential(*[
ConvBnAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
ConvBnAct(32, 64, kernel_size=3, stride=1, **layer_args)
ConvNormAct(in_chans, 32, kernel_size=3, stride=2, **layer_args),
create_conv2d(32, 64, kernel_size=3, stride=1) if preact else
ConvNormAct(32, 64, kernel_size=3, stride=1, **layer_args)
])
curr_dilation = 1
curr_stride = 2
self.feature_info = []
self.blocks = nn.Sequential()
module_fn = PreXceptionModule if preact else XceptionModule
for i, b in enumerate(block_cfg):
b['dilation'] = curr_dilation
if b['stride'] > 1:
self.feature_info += [dict(
num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=f'blocks.{i}.stack.act3')]
name = f'blocks.{i}.stack.conv2' if preact else f'blocks.{i}.stack.act3'
self.feature_info += [dict(num_chs=to_3tuple(b['out_chs'])[-2], reduction=curr_stride, module=name)]
next_stride = curr_stride * b['stride']
if next_stride > output_stride:
curr_dilation *= b['stride']
b['stride'] = 1
else:
curr_stride = next_stride
self.blocks.add_module(str(i), XceptionModule(**b, **layer_args))
self.blocks.add_module(str(i), module_fn(**b, **layer_args))
self.num_features = self.blocks[-1].out_channels
self.feature_info += [dict(
num_chs=self.num_features, reduction=curr_stride, module='blocks.' + str(len(self.blocks) - 1))]
self.act = act_layer(inplace=True) if preact else nn.Identity()
self.head = ClassifierHead(
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
@ -163,6 +215,7 @@ class XceptionAligned(nn.Module):
def forward_features(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.act(x)
return x
def forward(self, x):
@ -236,3 +289,22 @@ def xception71(pretrained=False, **kwargs):
]
model_args = dict(block_cfg=block_cfg, norm_layer=partial(nn.BatchNorm2d, eps=.001, momentum=.1), **kwargs)
return _xception('xception71', pretrained=pretrained, **model_args)
@register_model
def xception41p(pretrained=False, **kwargs):
""" Modified Aligned Xception-41 w/ Pre-Act
"""
block_cfg = [
# entry flow
dict(in_chs=64, out_chs=128, stride=2),
dict(in_chs=128, out_chs=256, stride=2),
dict(in_chs=256, out_chs=728, stride=2),
# middle flow
*([dict(in_chs=728, out_chs=728, stride=1)] * 8),
# exit flow
dict(in_chs=728, out_chs=(728, 1024, 1024), stride=2),
dict(in_chs=1024, out_chs=(1536, 1536, 2048), no_skip=True, stride=1),
]
model_args = dict(block_cfg=block_cfg, preact=True, norm_layer=nn.BatchNorm2d, **kwargs)
return _xception('xception41p', pretrained=pretrained, **model_args)
