Merge branch 'norm_norm_norm' into bits_and_tpu

pull/1239/head
Ross Wightman 2 years ago
commit cbc4f33220

@ -120,7 +120,7 @@ default_cfgs = {
# FIXME experimental
'efficientnet_b0_gn': _cfg(
url=''),
'efficientnet_b0_g8': _cfg(
'efficientnet_b0_g8_gn': _cfg(
url=''),
'efficientnet_b0_g16_evos': _cfg(
url=''),
@ -1389,10 +1389,11 @@ def efficientnet_b0_gn(pretrained=False, **kwargs):
@register_model
def efficientnet_b0_g8(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group conv + BN"""
def efficientnet_b0_g8_gn(pretrained=False, **kwargs):
""" EfficientNet-B0 w/ group conv + GroupNorm"""
model = _gen_efficientnet(
'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs)
'efficientnet_b0_g8_gn', group_size=8, norm_layer=partial(GroupNormAct, group_size=8),
pretrained=pretrained, **kwargs)
return model

@ -19,11 +19,7 @@ def num_groups(group_size, channels):
return 1 # normal conv with 1 group
else:
# NOTE group_size == 1 -> depthwise conv
#assert channels % group_size == 0
if channels % group_size != 0:
num_groups = math.floor(channels / group_size)
print(channels, group_size, num_groups)
return int(num_groups)
assert channels % group_size == 0
return channels // group_size
@ -87,7 +83,7 @@ class ConvBnAct(nn.Module):
x = self.conv(x)
x = self.bn1(x)
if self.has_skip:
x = x + self.drop_path(shortcut)
x = self.drop_path(x) + shortcut
return x
@ -131,7 +127,7 @@ class DepthwiseSeparableConv(nn.Module):
x = self.conv_pw(x)
x = self.bn2(x)
if self.has_skip:
x = x + self.drop_path(shortcut)
x = self.drop_path(x) + shortcut
return x
@ -190,7 +186,7 @@ class InvertedResidual(nn.Module):
x = self.conv_pwl(x)
x = self.bn3(x)
if self.has_skip:
x = x + self.drop_path(shortcut)
x = self.drop_path(x) + shortcut
return x
@ -225,7 +221,7 @@ class CondConvResidual(InvertedResidual):
x = self.conv_pwl(x, routing_weights)
x = self.bn3(x)
if self.has_skip:
x = x + self.drop_path(shortcut)
x = self.drop_path(x) + shortcut
return x
@ -281,5 +277,5 @@ class EdgeResidual(nn.Module):
x = self.conv_pwl(x)
x = self.bn2(x)
if self.has_skip:
x = x + self.drop_path(shortcut)
x = self.drop_path(x) + shortcut
return x

@ -40,7 +40,7 @@ def get_bn_args_tf():
def resolve_bn_args(kwargs):
bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {}
bn_args = {}
bn_momentum = kwargs.pop('bn_momentum', None)
if bn_momentum is not None:
bn_args['momentum'] = bn_momentum

@ -47,13 +47,6 @@ def create_model(
"""
source_name, model_name = split_model_name(model_name)
# Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args
is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3'])
if not is_efficientnet:
kwargs.pop('bn_tf', None)
kwargs.pop('bn_momentum', None)
kwargs.pop('bn_eps', None)
# handle backwards compat with drop_connect -> drop_path change
drop_connect_rate = kwargs.pop('drop_connect_rate', None)
if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None:

@ -13,15 +13,18 @@ Weights from original impl have been modified
Hacked together by / Copyright 2020 Ross Wightman
"""
import numpy as np
import torch.nn as nn
import math
from dataclasses import dataclass
from functools import partial
from typing import Optional, Union, Callable
import numpy as np
import torch.nn as nn
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg, named_apply
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct
from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
from .layers import get_act_layer, get_norm_act_layer, create_conv2d
from .registry import register_model
@ -37,6 +40,8 @@ class RegNetCfg:
stem_width: int = 32
downsample: Optional[str] = 'conv1x1'
linear_out: bool = False
preact: bool = False
num_features: int = 0
act_layer: Union[str, Callable] = 'relu'
norm_layer: Union[str, Callable] = 'batchnorm'
@ -75,15 +80,23 @@ model_cfgs = dict(
regnety_040s_gn=RegNetCfg(
w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25,
act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)),
# regnetv = 'preact regnet y'
regnetv_040=RegNetCfg(
depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'),
# regnetw = 'preact regnet z'
regnetw_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, preact=True, num_features=1536, act_layer='silu',
),
# RegNet-Z (unverified)
regnetz_005=RegNetCfg(
depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
downsample=None, linear_out=True, num_features=1024, act_layer='silu',
),
regnetz_040=RegNetCfg(
depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25,
downsample=None, linear_out=True, act_layer='silu',
downsample=None, linear_out=True, num_features=1536, act_layer='silu',
),
)
@ -130,6 +143,8 @@ default_cfgs = dict(
regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'),
regnety_040s_gn=_cfg(url=''),
regnetv_040=_cfg(url=''),
regnetw_040=_cfg(url=''),
regnetz_005=_cfg(url=''),
regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)),
@ -162,15 +177,18 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8):
return widths, num_stages, max_stage, widths_cont
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
norm_layer = norm_layer or nn.BatchNorm2d
kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
dilation = dilation if kernel_size > 1 else 1
return ConvNormAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
if preact:
return create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation)
else:
return ConvNormAct(
in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False)
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None):
def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False):
""" AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment."""
norm_layer = norm_layer or nn.BatchNorm2d
avg_stride = stride if dilation == 1 else 1
@ -178,20 +196,24 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
return nn.Sequential(*[
pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)])
if preact:
conv = create_conv2d(in_chs, out_chs, 1, stride=1)
else:
conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)
return nn.Sequential(*[pool, conv])
def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None):
def create_shortcut(
downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False):
assert downsample_type in ('avg', 'conv1x1', '', None)
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact)
if not downsample_type:
return None # no shortcut, no downsample
elif downsample_type == 'avg':
return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
return downsample_avg(in_chs, out_chs, **dargs)
else:
return downsample_conv(
in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer)
return downsample_conv(in_chs, out_chs, kernel_size=kernel_size, **dargs)
else:
return nn.Identity() # identity shortcut (no downsample)
@ -203,9 +225,10 @@ class Bottleneck(nn.Module):
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
"""
def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
super(Bottleneck, self).__init__()
act_layer = get_act_layer(act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
@ -238,22 +261,68 @@ class Bottleneck(nn.Module):
if self.downsample is not None:
# NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = x + self.drop_path(self.downsample(shortcut))
x = self.drop_path(x) + self.downsample(shortcut)
x = self.act3(x)
return x
class PreBottleneck(nn.Module):
""" RegNet Bottleneck
This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from
after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels.
"""
def __init__(
self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25,
downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_block=None, drop_path_rate=0.):
super(PreBottleneck, self).__init__()
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
bottleneck_chs = int(round(out_chs * bottle_ratio))
groups = bottleneck_chs // group_size
self.norm1 = norm_act_layer(in_chs)
self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1)
self.norm2 = norm_act_layer(bottleneck_chs)
self.conv2 = create_conv2d(
bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], groups=groups)
if se_ratio:
se_channels = int(round(in_chs * se_ratio))
self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer)
else:
self.se = nn.Identity()
self.norm3 = norm_act_layer(bottleneck_chs)
self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1)
self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, preact=True)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
def zero_init_last(self):
pass
def forward(self, x):
x = self.norm1(x)
shortcut = x
x = self.conv1(x)
x = self.norm2(x)
x = self.conv2(x)
x = self.se(x)
x = self.norm3(x)
x = self.conv3(x)
if self.downsample is not None:
# NOTE stuck with downsample as the attr name due to weight compatibility
# now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity()
x = self.drop_path(x) + self.downsample(shortcut)
return x
class RegStage(nn.Module):
"""Stage (sequence of blocks w/ the same output shape)."""
def __init__(
self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck,
se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d,
drop_path_rates=None, drop_block=None):
self, depth, in_chs, out_chs, stride, dilation,
drop_path_rates=None, block_fn=Bottleneck, **block_kwargs):
super(RegStage, self).__init__()
block_kwargs = dict(
bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample,
linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block)
first_dilation = 1 if dilation in (1, 2) else 2
for i in range(depth):
block_stride = stride if i == 0 else 1
@ -291,30 +360,40 @@ class RegNet(nn.Module):
# Construct the stem
stem_width = cfg.stem_width
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
if cfg.preact:
self.stem = create_conv2d(in_chans, stem_width, 3, stride=2)
else:
self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args)
self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')]
# Construct the stages
prev_width = stem_width
curr_stride = 2
stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
for i, stage_args in enumerate(stage_params):
per_stage_args, common_args = self._get_stage_args(
cfg, output_stride=output_stride, drop_path_rate=drop_path_rate)
block_fn = PreBottleneck if cfg.preact else Bottleneck
for i, stage_args in enumerate(per_stage_args):
stage_name = "s{}".format(i + 1)
self.add_module(stage_name, RegStage(
in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args))
self.add_module(stage_name, RegStage(in_chs=prev_width, block_fn=block_fn, **stage_args, **common_args))
prev_width = stage_args['out_chs']
curr_stride *= stage_args['stride']
self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)]
# Construct the head
self.num_features = prev_width
if cfg.num_features:
self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args)
self.num_features = cfg.num_features
else:
final_act = cfg.linear_out or cfg.preact
self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity()
self.num_features = prev_width
self.head = ClassifierHead(
in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)
def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.):
# Generate RegNet ws per block
widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth)
@ -341,12 +420,15 @@ class RegNet(nn.Module):
# Adjust the compatibility of ws and gws
stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups)
param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
stage_params = [
dict(zip(param_names, params)) for params in
arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates']
per_stage_args = [
dict(zip(arg_names, params)) for params in
zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups,
stage_dpr)]
return stage_params
common_args = dict(
downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out,
act_layer=cfg.act_layer, norm_layer=cfg.norm_layer)
return per_stage_args, common_args
def get_classifier(self):
return self.head.fc
@ -367,14 +449,16 @@ class RegNet(nn.Module):
def _init_weights(module, name='', zero_init_last=False):
if isinstance(module, nn.Conv2d):
nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(module, nn.BatchNorm2d):
nn.init.ones_(module.weight)
nn.init.zeros_(module.bias)
fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
fan_out //= module.groups
module.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Linear):
nn.init.normal_(module.weight, mean=0.0, std=0.01)
nn.init.zeros_(module.bias)
elif hasattr(module, 'zero_init_last'):
if module.bias is not None:
nn.init.zeros_(module.bias)
elif zero_init_last and hasattr(module, 'zero_init_last'):
module.zero_init_last()
@ -545,13 +629,25 @@ def regnety_040s_gn(pretrained=False, **kwargs):
return _create_regnet('regnety_040s_gn', pretrained, **kwargs)
@register_model
def regnetv_040(pretrained=False, **kwargs):
""""""
return _create_regnet('regnetv_040', pretrained, **kwargs)
@register_model
def regnetw_040(pretrained=False, **kwargs):
""""""
return _create_regnet('regnetw_040', pretrained, **kwargs)
@register_model
def regnetz_005(pretrained=False, **kwargs):
"""RegNetZ-500MF
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet('regnetz_005', pretrained, **kwargs)
return _create_regnet('regnetz_005', pretrained, zero_init_last=False, **kwargs)
@register_model
@ -560,4 +656,4 @@ def regnetz_040(pretrained=False, **kwargs):
NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py
but it's not clear it is equivalent to paper model as not detailed in the paper.
"""
return _create_regnet('regnetz_040', pretrained, **kwargs)
return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs)

Loading…
Cancel
Save