Fix up a few details in NFResNet models, managed stable training. Add support for gamma gain to be applied in activation or ScaleStdConv. Some tweaks to ScaledStdConv.

pull/389/head
Ross Wightman 3 years ago
parent 5a8e1e643e
commit 90980de4a9

@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
from .conv2d_same import conv2d_same
@ -69,20 +68,24 @@ class ScaledStdConv2d(nn.Conv2d):
https://arxiv.org/abs/2101.08692
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=None, dilation=1, groups=1,
bias=True, gain=True, gamma=1.0, eps=1e-5, use_layernorm=False):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
self.gamma = gamma * self.weight[0].numel() ** 0.5 # gamma * sqrt(fan-in)
self.eps = eps
self.scale = gamma * self.weight[0].numel() ** -0.5 # gamma * 1 / sqrt(fan-in)
self.eps = eps ** 2 if use_layernorm else eps
self.use_layernorm = use_layernorm # experimental, slightly faster/less GPU memory use
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (self.gamma * std + self.eps)
if self.use_layernorm:
weight = self.scale * F.layer_norm(self.weight, self.weight.shape[1:], eps=self.eps)
else:
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = self.scale * (self.weight - mean) / (std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight

@ -18,7 +18,7 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible, get_act_fn
def _dcfg(url='', **kwargs):
@ -40,17 +40,17 @@ default_cfgs = {
'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),
'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
'nf_resnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_resnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_resnet101': _dcfg(url='', first_conv='stem.conv'),
'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_seresnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_seresnet101': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet26': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet50': _dcfg(url='', first_conv='stem.conv'),
'nf_ecaresnet101': _dcfg(url='', first_conv='stem.conv'),
}
@ -59,6 +59,7 @@ class NfCfg:
depths: Tuple[int, int, int, int]
channels: Tuple[int, int, int, int]
alpha: float = 0.2
gamma_in_act: bool = False
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = 8
@ -84,68 +85,65 @@ model_cfgs = dict(
nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
# ResNet (preact, D style deep stem/avg down) defs
nf_resnet26d=NfCfg(
nf_resnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None,),
nf_resnet50d=NfCfg(
nf_resnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_resnet101d=NfCfg(
nf_resnet101=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_seresnet26d=NfCfg(
nf_seresnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet50d=NfCfg(
nf_seresnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet101d=NfCfg(
nf_seresnet101=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_ecaresnet26d=NfCfg(
nf_ecaresnet26=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50d=NfCfg(
nf_ecaresnet50=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101d=NfCfg(
nf_ecaresnet101=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
stem_type='7x7_pool', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
)
# class NormFreeSiLU(nn.Module):
# _K = 1. / 0.5595
# def __init__(self, inplace=False):
# super().__init__()
# self.inplace = inplace
#
# def forward(self, x):
# return F.silu(x, inplace=self.inplace) * self._K
#
#
# class NormFreeReLU(nn.Module):
# _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
#
# def __init__(self, inplace=False):
# super().__init__()
# self.inplace = inplace
#
# def forward(self, x):
# return F.relu(x, inplace=self.inplace) * self._K
class GammaAct(nn.Module):
def __init__(self, act_type='relu', gamma: float = 1.0, inplace=False):
super().__init__()
self.act_fn = get_act_fn(act_type)
self.gamma = gamma
self.inplace = inplace
def forward(self, x):
return self.gamma * self.act_fn(x, inplace=self.inplace)
def act_with_gamma(act_type, gamma: float = 1.):
def _create(inplace=False):
return GammaAct(act_type, gamma=gamma, inplace=inplace)
return _create
class DownsampleAvg(nn.Module):
@ -178,10 +176,9 @@ class NormalizationFreeBlock(nn.Module):
out_chs = out_chs or in_chs
# EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
groups = 1
if group_size is not None:
# NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
groups = mid_chs // group_size
groups = 1 if group_size is None else mid_chs // group_size
if group_size and group_size % ch_div == 0:
mid_chs = group_size * groups # correct mid_chs if group_size divisible by ch_div, otherwise error
self.alpha = alpha
self.beta = beta
self.attn_gain = attn_gain
@ -229,10 +226,11 @@ class NormalizationFreeBlock(nn.Module):
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
stem_stride = 2
stem = OrderedDict()
assert stem_type in ('', 'deep', '3x3', '7x7')
assert stem_type in ('', 'deep', '3x3', '7x7', 'deep_pool', '3x3_pool', '7x7_pool')
if 'deep' in stem_type:
# 3 deep 3x3 conv stack as in ResNet V1D models
# 3 deep 3x3 conv stack as in ResNet V1D models. NOTE: doesn't work as well here
mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
@ -244,12 +242,16 @@ def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
# 7x7 stem conv as in ResNet
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
return nn.Sequential(stem)
if 'pool' in stem_type:
stem['pool'] = nn.MaxPool2d(3, stride=2, padding=1)
stem_stride = 4
return nn.Sequential(stem), stem_stride
_nonlin_gamma = dict(
silu=.5595,
relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
silu=1./.5595,
relu=(0.5 * (1. - 1. / math.pi)) ** -0.5,
identity=1.0
)
@ -264,9 +266,12 @@ class NormalizerFreeNet(nn.Module):
the (preact) ResNet models described earlier in the paper.
There are a few differences:
* channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
* channels are rounded to be divisible by 8 by default (keep tensor core kernels happy),
this changes channel dim and param counts slightly from the paper models
* activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
* a config option `gamma_in_act` can be enabled to not apply gamma in StdConv as described above, but
apply it in each activation. This is slightly slower, and yields slightly different results.
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
for what it is/does. Approx 8-10% throughput loss.
"""
@ -275,29 +280,33 @@ class NormalizerFreeNet(nn.Module):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
act_layer = get_act_layer(cfg.act_layer)
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
if cfg.gamma_in_act:
act_layer = act_with_gamma(cfg.act_layer, gamma=_nonlin_gamma[cfg.act_layer])
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True)
else:
act_layer = get_act_layer(cfg.act_layer)
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self.feature_info = [] # FIXME fill out feature info
stem_chs = cfg.stem_chs or cfg.channels[0]
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
self.stem, stem_stride = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
prev_chs = stem_chs
self.feature_info = [] # NOTE: there will be no stride == 2 feature if stem_stride == 4
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
net_stride = 2
prev_chs = stem_chs
net_stride = stem_stride
dilation = 1
expected_var = 1.0
stages = []
for stage_idx, stage_depth in enumerate(cfg.depths):
if net_stride >= output_stride:
dilation *= 2
stride = 1 if stage_idx == 0 and stem_stride > 2 else 2
self.feature_info += [dict(
num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}.0.act1' if stride == 2 else '')]
if net_stride >= output_stride and stride > 1:
dilation *= stride
stride = 1
else:
stride = 2
net_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
@ -338,7 +347,10 @@ class NormalizerFreeNet(nn.Module):
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
# FIXME not 100% clear on gamma subtleties final conv/final act in case where it's in stdconv
self.final_act = act_layer()
self.feature_info += [dict(num_chs=self.num_features, reduction=net_stride, module='final_act')]
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
@ -373,11 +385,14 @@ class NormalizerFreeNet(nn.Module):
def _create_normfreenet(variant, pretrained=False, **kwargs):
model_cfg = model_cfgs[variant]
feature_cfg = dict(flatten_sequential=True)
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
if 'pool' in model_cfg.stem_type:
feature_cfg['out_indices'] = (1, 2, 3, 4) # no stride 2, 0 level feat for stride 4 maxpool stems in ResNet
return build_model_with_cfg(
NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
NormalizerFreeNet, variant, pretrained, model_cfg=model_cfg, default_cfg=default_cfgs[variant],
feature_cfg=feature_cfg, **kwargs)
@ -412,30 +427,30 @@ def nf_regnet_b5(pretrained=False, **kwargs):
@register_model
def nf_resnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
def nf_resnet26(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
def nf_resnet50(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet50', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
def nf_seresnet26(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
def nf_seresnet50(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet50', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
def nf_ecaresnet26(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet26', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
def nf_ecaresnet50(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet50', pretrained=pretrained, **kwargs)

Loading…
Cancel
Save