diff --git a/timm/models/efficientnet.py b/timm/models/efficientnet.py index b38b3c0e..b6ba0209 100644 --- a/timm/models/efficientnet.py +++ b/timm/models/efficientnet.py @@ -120,7 +120,7 @@ default_cfgs = { # FIXME experimental 'efficientnet_b0_gn': _cfg( url=''), - 'efficientnet_b0_g8': _cfg( + 'efficientnet_b0_g8_gn': _cfg( url=''), 'efficientnet_b0_g16_evos': _cfg( url=''), @@ -1389,10 +1389,11 @@ def efficientnet_b0_gn(pretrained=False, **kwargs): @register_model -def efficientnet_b0_g8(pretrained=False, **kwargs): - """ EfficientNet-B0 w/ group conv + BN""" +def efficientnet_b0_g8_gn(pretrained=False, **kwargs): + """ EfficientNet-B0 w/ group conv + GroupNorm""" model = _gen_efficientnet( - 'efficientnet_b0_g8', group_size=8, pretrained=pretrained, **kwargs) + 'efficientnet_b0_g8_gn', group_size=8, norm_layer=partial(GroupNormAct, group_size=8), + pretrained=pretrained, **kwargs) return model diff --git a/timm/models/efficientnet_blocks.py b/timm/models/efficientnet_blocks.py index 0e91319b..b842e82c 100644 --- a/timm/models/efficientnet_blocks.py +++ b/timm/models/efficientnet_blocks.py @@ -19,11 +19,7 @@ def num_groups(group_size, channels): return 1 # normal conv with 1 group else: # NOTE group_size == 1 -> depthwise conv - #assert channels % group_size == 0 - if channels % group_size != 0: - num_groups = math.floor(channels / group_size) - print(channels, group_size, num_groups) - return int(num_groups) + assert channels % group_size == 0 return channels // group_size @@ -87,7 +83,7 @@ class ConvBnAct(nn.Module): x = self.conv(x) x = self.bn1(x) if self.has_skip: - x = x + self.drop_path(shortcut) + x = self.drop_path(x) + shortcut return x @@ -131,7 +127,7 @@ class DepthwiseSeparableConv(nn.Module): x = self.conv_pw(x) x = self.bn2(x) if self.has_skip: - x = x + self.drop_path(shortcut) + x = self.drop_path(x) + shortcut return x @@ -190,7 +186,7 @@ class InvertedResidual(nn.Module): x = self.conv_pwl(x) x = self.bn3(x) if self.has_skip: - x = x + self.drop_path(shortcut) + x = self.drop_path(x) + shortcut return x @@ -225,7 +221,7 @@ class CondConvResidual(InvertedResidual): x = self.conv_pwl(x, routing_weights) x = self.bn3(x) if self.has_skip: - x = x + self.drop_path(shortcut) + x = self.drop_path(x) + shortcut return x @@ -281,5 +277,5 @@ class EdgeResidual(nn.Module): x = self.conv_pwl(x) x = self.bn2(x) if self.has_skip: - x = x + self.drop_path(shortcut) + x = self.drop_path(x) + shortcut return x diff --git a/timm/models/efficientnet_builder.py b/timm/models/efficientnet_builder.py index a102a872..023f10a3 100644 --- a/timm/models/efficientnet_builder.py +++ b/timm/models/efficientnet_builder.py @@ -40,7 +40,7 @@ def get_bn_args_tf(): def resolve_bn_args(kwargs): - bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_args = {} bn_momentum = kwargs.pop('bn_momentum', None) if bn_momentum is not None: bn_args['momentum'] = bn_momentum diff --git a/timm/models/factory.py b/timm/models/factory.py index d040a9ff..6d3fd982 100644 --- a/timm/models/factory.py +++ b/timm/models/factory.py @@ -47,13 +47,6 @@ def create_model( """ source_name, model_name = split_model_name(model_name) - # Only EfficientNet and MobileNetV3 models have support for batchnorm params or drop_connect_rate passed as args - is_efficientnet = is_model_in_modules(model_name, ['efficientnet', 'mobilenetv3']) - if not is_efficientnet: - kwargs.pop('bn_tf', None) - kwargs.pop('bn_momentum', None) - kwargs.pop('bn_eps', None) - # handle backwards compat with drop_connect -> drop_path change drop_connect_rate = kwargs.pop('drop_connect_rate', None) if drop_connect_rate is not None and kwargs.get('drop_path_rate', None) is None: diff --git a/timm/models/regnet.py b/timm/models/regnet.py index 8a0689f7..93e31bd8 100644 --- a/timm/models/regnet.py +++ b/timm/models/regnet.py @@ -13,15 +13,18 @@ Weights from original impl have been modified Hacked together by / Copyright 2020 Ross Wightman """ -import numpy as np -import torch.nn as nn +import math from dataclasses import dataclass from functools import partial from typing import Optional, Union, Callable +import numpy as np +import torch.nn as nn + from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply -from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, get_act_layer, GroupNormAct +from .layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct +from .layers import get_act_layer, get_norm_act_layer, create_conv2d from .registry import register_model @@ -37,6 +40,8 @@ class RegNetCfg: stem_width: int = 32 downsample: Optional[str] = 'conv1x1' linear_out: bool = False + preact: bool = False + num_features: int = 0 act_layer: Union[str, Callable] = 'relu' norm_layer: Union[str, Callable] = 'batchnorm' @@ -75,15 +80,23 @@ model_cfgs = dict( regnety_040s_gn=RegNetCfg( w0=96, wa=31.41, wm=2.24, group_size=64, depth=22, se_ratio=0.25, act_layer='silu', norm_layer=partial(GroupNormAct, group_size=16)), + # regnetv = 'preact regnet y' + regnetv_040=RegNetCfg( + depth=22, w0=96, wa=31.41, wm=2.24, group_size=64, se_ratio=0.25, preact=True, act_layer='silu'), + # regnetw = 'preact regnet z' + regnetw_040=RegNetCfg( + depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, + downsample=None, preact=True, num_features=1536, act_layer='silu', + ), # RegNet-Z (unverified) regnetz_005=RegNetCfg( depth=21, w0=16, wa=10.7, wm=2.51, group_size=4, bottle_ratio=4.0, se_ratio=0.25, - downsample=None, linear_out=True, act_layer='silu', + downsample=None, linear_out=True, num_features=1024, act_layer='silu', ), regnetz_040=RegNetCfg( depth=28, w0=48, wa=14.5, wm=2.226, group_size=8, bottle_ratio=4.0, se_ratio=0.25, - downsample=None, linear_out=True, act_layer='silu', + downsample=None, linear_out=True, num_features=1536, act_layer='silu', ), ) @@ -130,6 +143,8 @@ default_cfgs = dict( regnety_320=_cfg(url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-regnet/regnety_320-ba464b29.pth'), regnety_040s_gn=_cfg(url=''), + regnetv_040=_cfg(url=''), + regnetw_040=_cfg(url=''), regnetz_005=_cfg(url=''), regnetz_040=_cfg(url='', input_size=(3, 256, 256), pool_size=(8, 8)), @@ -162,15 +177,18 @@ def generate_regnet(width_slope, width_initial, width_mult, depth, q=8): return widths, num_stages, max_stage, widths_cont -def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): +def downsample_conv(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False): norm_layer = norm_layer or nn.BatchNorm2d kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size dilation = dilation if kernel_size > 1 else 1 - return ConvNormAct( - in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False) + if preact: + return create_conv2d(in_chs, out_chs, kernel_size, stride=stride, dilation=dilation) + else: + return ConvNormAct( + in_chs, out_chs, kernel_size, stride=stride, dilation=dilation, norm_layer=norm_layer, apply_act=False) -def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None): +def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_layer=None, preact=False): """ AvgPool Downsampling as in 'D' ResNet variants. This is not in RegNet space but I might experiment.""" norm_layer = norm_layer or nn.BatchNorm2d avg_stride = stride if dilation == 1 else 1 @@ -178,20 +196,24 @@ def downsample_avg(in_chs, out_chs, kernel_size=1, stride=1, dilation=1, norm_la if stride > 1 or dilation > 1: avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False) - return nn.Sequential(*[ - pool, ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False)]) + if preact: + conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + conv = ConvNormAct(in_chs, out_chs, 1, stride=1, norm_layer=norm_layer, apply_act=False) + return nn.Sequential(*[pool, conv]) -def create_shortcut(downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None): +def create_shortcut( + downsample_type, in_chs, out_chs, kernel_size, stride, dilation=(1, 1), norm_layer=None, preact=False): assert downsample_type in ('avg', 'conv1x1', '', None) if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + dargs = dict(stride=stride, dilation=dilation[0], norm_layer=norm_layer, preact=preact) if not downsample_type: return None # no shortcut, no downsample elif downsample_type == 'avg': - return downsample_avg(in_chs, out_chs, stride=stride, dilation=dilation[0], norm_layer=norm_layer) + return downsample_avg(in_chs, out_chs, **dargs) else: - return downsample_conv( - in_chs, out_chs, kernel_size=kernel_size, stride=stride, dilation=dilation[0], norm_layer=norm_layer) + return downsample_conv(in_chs, out_chs, kernel_size=kernel_size, **dargs) else: return nn.Identity() # identity shortcut (no downsample) @@ -203,9 +225,10 @@ class Bottleneck(nn.Module): after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. """ - def __init__(self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, - downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_block=None, drop_path_rate=0.): + def __init__( + self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, + downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path_rate=0.): super(Bottleneck, self).__init__() act_layer = get_act_layer(act_layer) bottleneck_chs = int(round(out_chs * bottle_ratio)) @@ -238,22 +261,68 @@ class Bottleneck(nn.Module): if self.downsample is not None: # NOTE stuck with downsample as the attr name due to weight compatibility # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() - x = x + self.drop_path(self.downsample(shortcut)) + x = self.drop_path(x) + self.downsample(shortcut) x = self.act3(x) return x +class PreBottleneck(nn.Module): + """ RegNet Bottleneck + + This is almost exactly the same as a ResNet Bottlneck. The main difference is the SE block is moved from + after conv3 to after conv2. Otherwise, it's just redefining the arguments for groups/bottleneck channels. + """ + + def __init__( + self, in_chs, out_chs, stride=1, dilation=(1, 1), bottle_ratio=1, group_size=1, se_ratio=0.25, + downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, + drop_block=None, drop_path_rate=0.): + super(PreBottleneck, self).__init__() + norm_act_layer = get_norm_act_layer(norm_layer, act_layer) + bottleneck_chs = int(round(out_chs * bottle_ratio)) + groups = bottleneck_chs // group_size + + self.norm1 = norm_act_layer(in_chs) + self.conv1 = create_conv2d(in_chs, bottleneck_chs, kernel_size=1) + self.norm2 = norm_act_layer(bottleneck_chs) + self.conv2 = create_conv2d( + bottleneck_chs, bottleneck_chs, kernel_size=3, stride=stride, dilation=dilation[0], groups=groups) + if se_ratio: + se_channels = int(round(in_chs * se_ratio)) + self.se = SEModule(bottleneck_chs, rd_channels=se_channels, act_layer=act_layer) + else: + self.se = nn.Identity() + self.norm3 = norm_act_layer(bottleneck_chs) + self.conv3 = create_conv2d(bottleneck_chs, out_chs, kernel_size=1) + self.downsample = create_shortcut(downsample, in_chs, out_chs, 1, stride, dilation, preact=True) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + + def zero_init_last(self): + pass + + def forward(self, x): + x = self.norm1(x) + shortcut = x + x = self.conv1(x) + x = self.norm2(x) + x = self.conv2(x) + x = self.se(x) + x = self.norm3(x) + x = self.conv3(x) + if self.downsample is not None: + # NOTE stuck with downsample as the attr name due to weight compatibility + # now represents the shortcut, no shortcut if None, and non-downsample shortcut == nn.Identity() + x = self.drop_path(x) + self.downsample(shortcut) + return x + + class RegStage(nn.Module): """Stage (sequence of blocks w/ the same output shape).""" def __init__( - self, depth, in_chs, out_chs, stride, dilation, bottle_ratio=1.0, group_size=8, block_fn=Bottleneck, - se_ratio=0., downsample='conv1x1', linear_out=False, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, - drop_path_rates=None, drop_block=None): + self, depth, in_chs, out_chs, stride, dilation, + drop_path_rates=None, block_fn=Bottleneck, **block_kwargs): super(RegStage, self).__init__() - block_kwargs = dict( - bottle_ratio=bottle_ratio, group_size=group_size, se_ratio=se_ratio, downsample=downsample, - linear_out=linear_out, act_layer=act_layer, norm_layer=norm_layer, drop_block=drop_block) first_dilation = 1 if dilation in (1, 2) else 2 for i in range(depth): block_stride = stride if i == 0 else 1 @@ -291,30 +360,40 @@ class RegNet(nn.Module): # Construct the stem stem_width = cfg.stem_width - self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + na_args = dict(act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + if cfg.preact: + self.stem = create_conv2d(in_chans, stem_width, 3, stride=2) + else: + self.stem = ConvNormAct(in_chans, stem_width, 3, stride=2, **na_args) self.feature_info = [dict(num_chs=stem_width, reduction=2, module='stem')] # Construct the stages prev_width = stem_width curr_stride = 2 - stage_params = self._get_stage_params(cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) - for i, stage_args in enumerate(stage_params): + per_stage_args, common_args = self._get_stage_args( + cfg, output_stride=output_stride, drop_path_rate=drop_path_rate) + block_fn = PreBottleneck if cfg.preact else Bottleneck + for i, stage_args in enumerate(per_stage_args): stage_name = "s{}".format(i + 1) - self.add_module(stage_name, RegStage( - in_chs=prev_width, se_ratio=cfg.se_ratio, downsample=cfg.downsample, linear_out=cfg.linear_out, - act_layer=cfg.act_layer, norm_layer=cfg.norm_layer, **stage_args)) + self.add_module(stage_name, RegStage(in_chs=prev_width, block_fn=block_fn, **stage_args, **common_args)) prev_width = stage_args['out_chs'] curr_stride *= stage_args['stride'] self.feature_info += [dict(num_chs=prev_width, reduction=curr_stride, module=stage_name)] # Construct the head - self.num_features = prev_width + if cfg.num_features: + self.final_conv = ConvNormAct(prev_width, cfg.num_features, kernel_size=1, **na_args) + self.num_features = cfg.num_features + else: + final_act = cfg.linear_out or cfg.preact + self.final_conv = get_act_layer(cfg.act_layer)() if final_act else nn.Identity() + self.num_features = prev_width self.head = ClassifierHead( - in_chs=prev_width, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) + in_chs=self.num_features, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate) named_apply(partial(_init_weights, zero_init_last=zero_init_last), self) - def _get_stage_params(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): + def _get_stage_args(self, cfg: RegNetCfg, default_stride=2, output_stride=32, drop_path_rate=0.): # Generate RegNet ws per block widths, num_stages, _, _ = generate_regnet(cfg.wa, cfg.w0, cfg.wm, cfg.depth) @@ -341,12 +420,15 @@ class RegNet(nn.Module): # Adjust the compatibility of ws and gws stage_widths, stage_groups = adjust_widths_groups_comp(stage_widths, stage_bottle_ratios, stage_groups) - param_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] - stage_params = [ - dict(zip(param_names, params)) for params in + arg_names = ['out_chs', 'stride', 'dilation', 'depth', 'bottle_ratio', 'group_size', 'drop_path_rates'] + per_stage_args = [ + dict(zip(arg_names, params)) for params in zip(stage_widths, stage_strides, stage_dilations, stage_depths, stage_bottle_ratios, stage_groups, stage_dpr)] - return stage_params + common_args = dict( + downsample=cfg.downsample, se_ratio=cfg.se_ratio, linear_out=cfg.linear_out, + act_layer=cfg.act_layer, norm_layer=cfg.norm_layer) + return per_stage_args, common_args def get_classifier(self): return self.head.fc @@ -367,14 +449,16 @@ class RegNet(nn.Module): def _init_weights(module, name='', zero_init_last=False): if isinstance(module, nn.Conv2d): - nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') - elif isinstance(module, nn.BatchNorm2d): - nn.init.ones_(module.weight) - nn.init.zeros_(module.bias) + fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels + fan_out //= module.groups + module.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if module.bias is not None: + module.bias.data.zero_() elif isinstance(module, nn.Linear): nn.init.normal_(module.weight, mean=0.0, std=0.01) - nn.init.zeros_(module.bias) - elif hasattr(module, 'zero_init_last'): + if module.bias is not None: + nn.init.zeros_(module.bias) + elif zero_init_last and hasattr(module, 'zero_init_last'): module.zero_init_last() @@ -545,13 +629,25 @@ def regnety_040s_gn(pretrained=False, **kwargs): return _create_regnet('regnety_040s_gn', pretrained, **kwargs) +@register_model +def regnetv_040(pretrained=False, **kwargs): + """""" + return _create_regnet('regnetv_040', pretrained, **kwargs) + + +@register_model +def regnetw_040(pretrained=False, **kwargs): + """""" + return _create_regnet('regnetw_040', pretrained, **kwargs) + + @register_model def regnetz_005(pretrained=False, **kwargs): """RegNetZ-500MF NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. """ - return _create_regnet('regnetz_005', pretrained, **kwargs) + return _create_regnet('regnetz_005', pretrained, zero_init_last=False, **kwargs) @register_model @@ -560,4 +656,4 @@ def regnetz_040(pretrained=False, **kwargs): NOTE: config found in https://github.com/facebookresearch/ClassyVision/blob/main/classy_vision/models/regnet.py but it's not clear it is equivalent to paper model as not detailed in the paper. """ - return _create_regnet('regnetz_040', pretrained, **kwargs) + return _create_regnet('regnetz_040', pretrained, zero_init_last=False, **kwargs)