From 90980de4a9a0d3419b50b4e51408a8a0fe266f29 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Sat, 30 Jan 2021 16:32:07 -0800
Subject: [PATCH] Fix up a few details in NFResNet models, managed stable
 training. Add support for gamma gain to be applied in activation or
 ScaledStdConv. Some tweaks to ScaledStdConv.

---
 timm/models/layers/std_conv.py |  17 ++--
 timm/models/nfnet.py           | 177 ++++++++++++++++++---------------
 2 files changed, 106 insertions(+), 88 deletions(-)

diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py
index d7f8274e..80a8e5d7 100644
--- a/timm/models/layers/std_conv.py
+++ b/timm/models/layers/std_conv.py
@@ -1,7 +1,6 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-import numpy as np
 
 from .padding import get_padding
 from .conv2d_same import conv2d_same
@@ -69,20 +68,24 @@ class ScaledStdConv2d(nn.Conv2d):
         https://arxiv.org/abs/2101.08692
     """
 
-    def __init__(self, in_channels, out_channels, kernel_size,
-                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
+                 bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
         if padding is None:
             padding = get_padding(kernel_size, stride, dilation)
         super().__init__(
             in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation,
             groups=groups, bias=bias)
         self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
-        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
-        self.eps = eps
+        self.scale = gamma * self.weight[0].numel() ** -0.5  # gamma * 1 / sqrt(fan-in)
+        self.eps = eps ** 2 if use_layernorm else eps
+        self.use_layernorm = use_layernorm  # experimental, slightly faster/less GPU memory use
 
     def get_weight(self):
-        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        weight = (self.weight - mean) / (self.gamma * std + self.eps)
+        if self.use_layernorm:
+            weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
+        else:
+            std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+            weight = self.scale * (self.weight - mean) / (std + self.eps)
         if self.gain is not None:
             weight = weight * self.gain
         return weight
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
index 69ca9fee..f54bb84a 100644
--- a/timm/models/nfnet.py
+++ b/timm/models/nfnet.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
+from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
 
 
 def _dcfg(url='', **kwargs):
@@ -40,17 +40,17 @@ default_cfgs = {
     'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
     'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),
 
-    'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
 
-    'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
 
-    'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
-    'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
+    'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
+    'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
 }
 
 
@@ -59,6 +59,7 @@ class NfCfg:
     depths: Tuple[int, int, int, int]
     channels: Tuple[int, int, int, int]
     alpha: float = 0.2
+    gamma_in_act: bool = False
     stem_type: str = '3x3'
     stem_chs: Optional[int] = None
     group_size: Optional[int] = 8
@@ -84,68 +85,65 @@ model_cfgs = dict(
     nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
 
     # ResNet (preact, D style deep stem/avg down) defs
-    nf_resnet26d=NfCfg(
+    nf_resnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None,),
-    nf_resnet50d=NfCfg(
+    nf_resnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
-    nf_resnet101d=NfCfg(
+    nf_resnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer=None),
 
-    nf_seresnet26d=NfCfg(
+    nf_seresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet50d=NfCfg(
+    nf_seresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
-    nf_seresnet101d=NfCfg(
+    nf_seresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
 
-    nf_ecaresnet26d=NfCfg(
+    nf_ecaresnet26=NfCfg(
         depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet50d=NfCfg(
+    nf_ecaresnet50=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
-    nf_ecaresnet101d=NfCfg(
+    nf_ecaresnet101=NfCfg(
         depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
-        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
         act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
 )
 
 
-# class NormFreeSiLU(nn.Module):
-#     _K = 1. / 0.5595
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.silu(x, inplace=self.inplace) * self._K
-#
-#
-# class NormFreeReLU(nn.Module):
-#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
-#
-#     def __init__(self, inplace=False):
-#         super().__init__()
-#         self.inplace = inplace
-#
-#     def forward(self, x):
-#         return F.relu(x, inplace=self.inplace) * self._K
+
+class GammaAct(nn.Module):
+    def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
+        super().__init__()
+        self.act_fn = get_act_fn(act_type)
+        self.gamma = gamma
+        self.inplace = inplace
+
+    def forward(self, x):
+        return self.gamma * self.act_fn(x, inplace=self.inplace)
+
+
+def act_with_gamma(act_type, gamma: float = 1.):
+    def _create(inplace=False):
+        return GammaAct(act_type, gamma=gamma, inplace=inplace)
+    return _create
 
 
 class DownsampleAvg(nn.Module):
@@ -178,10 +176,9 @@ class NormalizationFreeBlock(nn.Module):
         out_chs = out_chs or in_chs
         # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
         mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
-        groups = 1
-        if group_size is not None:
-            # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
-            groups = mid_chs // group_size
+        groups = 1 if group_size is None else mid_chs // group_size
+        if group_size and group_size % ch_div == 0:
+            mid_chs = group_size * groups  # correct mid_chs if group_size divisible by ch_div, otherwise error
         self.alpha = alpha
         self.beta = beta
         self.attn_gain = attn_gain
@@ -229,10 +226,11 @@ class NormalizationFreeBlock(nn.Module):
 
 
 def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
+    stem_stride = 2
     stem = OrderedDict()
-    assert stem_type in ('', 'deep', '3x3', '7x7')
+    assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
     if 'deep' in stem_type:
-        # 3 deep 3x3 conv stack as in ResNet V1D models
+        # 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
         mid_chs = out_chs // 2
         stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
         stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
@@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
         # 7x7 stem conv as in ResNet
         stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
 
-    return nn.Sequential(stem)
+    if 'pool' in stem_type:
+        stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
+        stem_stride = 4
+
+    return nn.Sequential(stem), stem_stride
 
 
 _nonlin_gamma = dict(
-    silu=.5595,
-    relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
+    silu=1./.5595,
+    relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
     identity=1.0
 )
 
@@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
     the (preact) ResNet models described earlier in the paper.
 
    There are a few differences:
-        * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
+        * channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
+          this changes channel dim and param counts slightly from the paper models
         * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
          impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+        * a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
+          apply it in each activation. This is slightly slower, and yields slightly different results.
         * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
          for what it is/does. Approx 8-10% throughput loss.
     """
@@ -275,29 +280,33 @@ class NormalizerFreeNet(nn.Module):
         super().__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        act_layer = get_act_layer(cfg.act_layer)
         assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
-        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+        if cfg.gamma_in_act:
+            act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
+        else:
+            act_layer = get_act_layer(cfg.act_layer)
+            conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
         attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
 
-        self.feature_info = []  # FIXME fill out feature info
-
         stem_chs = cfg.stem_chs or cfg.channels[0]
         stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
-        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+        self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
 
-        prev_chs = stem_chs
+        self.feature_info = []  # NOTE: there will be no stride == 2 feature if stem_stride == 4
         dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
-        net_stride = 2
+        prev_chs = stem_chs
+        net_stride = stem_stride
         dilation = 1
         expected_var = 1.0
         stages = []
         for stage_idx, stage_depth in enumerate(cfg.depths):
-            if net_stride >= output_stride:
-                dilation *= 2
+            stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
+            self.feature_info += [dict(
+                num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
+            if net_stride >= output_stride and stride > 1:
+                dilation *= stride
                 stride = 1
-            else:
-                stride = 2
             net_stride *= stride
             first_dilation = 1 if dilation in (1, 2) else 2
@@ -338,7 +347,10 @@ class NormalizerFreeNet(nn.Module):
         else:
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
+        # FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
         self.final_act = act_layer()
+        self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
+
         self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
 
         for n, m in self.named_modules():
@@ -373,11 +385,14 @@
 
 
 def _create_normfreenet(variant, pretrained=False, **kwargs):
+    model_cfg = model_cfgs[variant]
     feature_cfg = dict(flatten_sequential=True)
     feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
+    if 'pool' in model_cfg.stem_type:
+        feature_cfg['out_indices'] = (1, 2, 3, 4)  # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
     return build_model_with_cfg(
-        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
+        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
         feature_cfg=feature_cfg,
         **kwargs)
 
 
@@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):
 
 
 @register_model
-def nf_resnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
+def nf_resnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def nf_resnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
+def nf_resnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def nf_seresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
+def nf_seresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def nf_seresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
+def nf_seresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def nf_ecaresnet26d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet26(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
 
 
 @register_model
-def nf_ecaresnet50d(pretrained=False, **kwargs):
-    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
+def nf_ecaresnet50(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)
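
The get_weight() change in ScaledStdConv2d above can be sanity-checked in isolation. A minimal standalone sketch (not the timm module itself; the shapes, gamma and eps values are illustrative) showing that the experimental use_layernorm path produces the same standardized weight as the explicit std_mean branch, up to how eps enters each formula, which is why the constructor squares eps when use_layernorm is set:

import torch
import torch.nn.functional as F

# Standalone check of the two get_weight() paths in ScaledStdConv2d.
# F.layer_norm over dims [1:] gives each output filter zero mean and unit
# variance, like the explicit std_mean branch. layer_norm adds eps to the
# variance (hence eps ** 2 in the patched constructor), while the explicit
# branch adds eps to the std, so the two agree only to O(eps).
torch.manual_seed(0)
weight = torch.randn(64, 32, 3, 3)               # (out_ch, in_ch, kH, kW)
gamma, eps = 1.0, 1e-5
scale = gamma * weight[0].numel() ** -0.5        # gamma / sqrt(fan-in)

std, mean = torch.std_mean(weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
w_std = scale * (weight - mean) / (std + eps)

w_ln = scale * F.layer_norm(weight, weight.shape[1:], eps=eps ** 2)

print((w_std - w_ln).abs().max())                # tiny, on the order of eps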
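
The corrected _nonlin_gamma values can be verified empirically: the table now stores the gain 1 / std(act(x)), which get_weight() multiplies by, where the old code stored the std itself and divided by it. A minimal check, assuming only stock PyTorch; the sample count is arbitrary:

import math
import torch
import torch.nn.functional as F

# Empirical check of the variance-preserving gamma constants. For
# x ~ N(0, 1), gamma is chosen so that gamma * act(x) has ~unit variance,
# which is what keeps signal variance stable without normalization layers.
torch.manual_seed(0)
x = torch.randn(10_000_000)

for name, act, gamma in [
    ('relu', F.relu, (0.5 * (1. - 1. / math.pi)) ** -0.5),
    ('silu', F.silu, 1. / .5595),
]:
    print(f'{name}: std(gamma * act(x)) = {(gamma * act(x)).std():.4f}')  # ~1.0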
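
On the new gamma_in_act option: applying gamma at the activation (GammaAct) versus folding it into the weight scale of the following ScaledStdConv2d is equivalent for a purely linear conv, as the toy sketch below illustrates (plain nn.Conv2d stands in for the real modules). The real modules differ slightly (conv bias, per-filter gain, eps placement), which is consistent with the docstring note that the two configs yield slightly different results:

import math
import torch
import torch.nn as nn

# Toy illustration: for a bias-free linear conv, scaling the activation
# output (GammaAct style) equals scaling the conv output/weights
# (ScaledStdConv style), since conv(gamma * x) == gamma * conv(x).
torch.manual_seed(0)
conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
x = torch.randn(1, 8, 16, 16)
gamma = (0.5 * (1. - 1. / math.pi)) ** -0.5      # relu constant from above

y_act_side = conv(gamma * torch.relu(x))         # gamma applied in activation
y_conv_side = gamma * conv(torch.relu(x))        # gamma folded into conv scale
print(torch.allclose(y_act_side, y_conv_side, atol=1e-6))  # True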
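
Finally, a usage sketch for the renamed variants (assumes a timm checkout with this patch applied; no pretrained weights exist for these defs yet, so pretrained=False). The expected values in the comments follow from the feature_info bookkeeping above:

import torch
import timm

# The 'd' suffix is dropped: these defs now use a single 7x7 conv + maxpool
# stem ('7x7_pool'), so first_conv is 'stem.conv' and the stem is stride 4,
# meaning there is no stride-2 feature level (feature out_indices start at 1).
model = timm.create_model('nf_resnet50', pretrained=False)
out = model(torch.randn(2, 3, 224, 224))
print(out.shape)               # expected: torch.Size([2, 1000])
print(model.feature_info[-1])  # expected: num_chs=2048, reduction=32, module='final_act'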