Initial Normalizer-Free Reg/ResNet impl. A bit of related layer refactoring.

pull/389/head
Ross Wightman 3 years ago
parent 9a38416fbd
commit 5a8e1e643e

@ -11,6 +11,7 @@ from .inception_v3 import *
from .inception_v4 import *
from .mobilenetv3 import *
from .nasnet import *
from .nfnet import *
from .pnasnet import *
from .regnet import *
from .res2net import *

@ -10,13 +10,13 @@ from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set
from .conv2d_same import Conv2dSame, conv2d_same
from .conv_bn_act import ConvBnAct
from .create_act import create_act_layer, get_act_layer, get_act_fn
from .create_attn import create_attn
from .create_attn import get_attn, create_attn
from .create_conv2d import create_conv2d
from .create_norm_act import create_norm_act, get_norm_act_layer
from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
from .eca import EcaModule, CecaModule
from .evo_norm import EvoNormBatch2d, EvoNormSample2d
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
from .inplace_abn import InplaceAbn
from .linear import Linear
from .mixed_conv2d import MixedConv2d
@ -29,5 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
from .space_to_depth import SpaceToDepthModule
from .split_attn import SplitAttnConv2d
from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
from .weight_init import trunc_normal_

@ -8,7 +8,7 @@ from .eca import EcaModule, CecaModule
from .cbam import CbamModule, LightCbamModule
def create_attn(attn_type, channels, **kwargs):
def get_attn(attn_type):
module_cls = None
if attn_type is not None:
if isinstance(attn_type, str):
@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
module_cls = SEModule
else:
module_cls = attn_type
return module_cls
def create_attn(attn_type, channels, **kwargs):
module_cls = get_attn(attn_type)
if module_cls is not None:
# NOTE: it's expected the first (positional) argument of all attention layers is the # input channels
return module_cls(channels, **kwargs)
return None

@ -22,6 +22,10 @@ to_4tuple = _ntuple(4)
to_ntuple = _ntuple
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v

@ -1,13 +1,27 @@
from torch import nn as nn
import torch.nn.functional as F
from .create_act import create_act_layer
from .helpers import make_divisible
class SEModule(nn.Module):
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
gate_layer='sigmoid'):
""" SE Module as defined in original SE-Nets with a few additions
Additions include:
* min_channels can be specified to keep reduced channel count at a minimum (default: 8)
* divisor can be specified to keep channels rounded to specified values (default: 1)
* reduction channels can be specified directly by arg (if reduction_channels is set)
* reduction channels can be specified by float ratio (if reduction_ratio is set)
"""
def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
super(SEModule, self).__init__()
reduction_channels = reduction_channels or max(channels // reduction, min_channels)
if reduction_channels is not None:
reduction_channels = reduction_channels # direct specification highest priority, no rounding/min done
elif reduction_ratio is not None:
reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
else:
reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
self.act = act_layer(inplace=True)
self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)

@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from .padding import get_padding
from .conv2d_same import conv2d_same
def get_weight(module):
std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (module.weight - mean) / (std + module.eps)
return weight
class StdConv2d(nn.Conv2d):
"""Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1,
padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.eps = eps
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight
def forward(self, x):
x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
class StdConv2dSame(nn.Conv2d):
"""Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
https://arxiv.org/abs/1903.10520v2
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=0, dilation=dilation, groups=groups, bias=bias)
self.eps = eps
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (std + self.eps)
return weight
def forward(self, x):
x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
class ScaledStdConv2d(nn.Conv2d):
"""Conv2d layer with Scaled Weight Standardization.
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
https://arxiv.org/abs/2101.08692
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
if padding is None:
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channels, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=groups, bias=bias)
self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
self.gamma = gamma * self.weight[0].numel() ** 0.5 # gamma * sqrt(fan-in)
self.eps = eps
def get_weight(self):
std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
weight = (self.weight - mean) / (self.gamma * std + self.eps)
if self.gain is not None:
weight = weight * self.gain
return weight
def forward(self, x):
return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)

@ -0,0 +1,441 @@
""" Normalizer Free RegNet / ResNet (pre-activation) Models
Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
Hacked together by / copyright Ross Wightman, 2021.
"""
import math
from dataclasses import dataclass, field
from collections import OrderedDict
from typing import Tuple, Optional
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible
def _dcfg(url='', **kwargs):
return {
'url': url,
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
'crop_pct': 0.875, 'interpolation': 'bicubic',
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
'first_conv': 'stem.conv', 'classifier': 'head.fc',
**kwargs
}
# FIXME finish
default_cfgs = {
'nf_regnet_b0': _dcfg(url=''),
'nf_regnet_b1': _dcfg(url='', input_size=(3, 240, 240)),
'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256)),
'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272)),
'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)),
'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)),
'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'),
'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'),
}
@dataclass
class NfCfg:
depths: Tuple[int, int, int, int]
channels: Tuple[int, int, int, int]
alpha: float = 0.2
stem_type: str = '3x3'
stem_chs: Optional[int] = None
group_size: Optional[int] = 8
attn_layer: Optional[str] = 'se'
attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8))
attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used
width_factor: float = 0.75
bottle_ratio: float = 2.25
efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models
num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode)
ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal
skipinit: bool = False
act_layer: str = 'silu'
model_cfgs = dict(
# EffNet influenced RegNet defs
nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280),
nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280),
nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416),
nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536),
nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792),
nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048),
# ResNet (preact, D style deep stem/avg down) defs
nf_resnet26d=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None,),
nf_resnet50d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_resnet101d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer=None),
nf_seresnet26d=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet50d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_seresnet101d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
nf_ecaresnet26d=NfCfg(
depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet50d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
nf_ecaresnet101d=NfCfg(
depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
)
# class NormFreeSiLU(nn.Module):
# _K = 1. / 0.5595
# def __init__(self, inplace=False):
# super().__init__()
# self.inplace = inplace
#
# def forward(self, x):
# return F.silu(x, inplace=self.inplace) * self._K
#
#
# class NormFreeReLU(nn.Module):
# _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
#
# def __init__(self, inplace=False):
# super().__init__()
# self.inplace = inplace
#
# def forward(self, x):
# return F.relu(x, inplace=self.inplace) * self._K
class DownsampleAvg(nn.Module):
def __init__(
self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
""" AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
super(DownsampleAvg, self).__init__()
avg_stride = stride if dilation == 1 else 1
if stride > 1 or dilation > 1:
avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
else:
self.pool = nn.Identity()
self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
def forward(self, x):
return self.conv(self.pool(x))
class NormalizationFreeBlock(nn.Module):
"""Normalization-free pre-activation block.
"""
def __init__(
self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None,
alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None,
attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False):
super().__init__()
first_dilation = first_dilation or dilation
out_chs = out_chs or in_chs
# EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet
mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div)
groups = 1
if group_size is not None:
# NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand.
groups = mid_chs // group_size
self.alpha = alpha
self.beta = beta
self.attn_gain = attn_gain
if in_chs != out_chs or stride != 1 or dilation != first_dilation:
self.downsample = DownsampleAvg(
in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer)
else:
self.downsample = None
self.act1 = act_layer()
self.conv1 = conv_layer(in_chs, mid_chs, 1)
self.act2 = act_layer(inplace=True)
self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups)
if attn_layer is not None:
self.attn = attn_layer(mid_chs)
else:
self.attn = None
self.act3 = act_layer()
self.conv3 = conv_layer(mid_chs, out_chs, 1)
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None
def forward(self, x):
out = self.act1(x) * self.beta
# shortcut branch
shortcut = x
if self.downsample is not None:
shortcut = self.downsample(out)
# residual branch
out = self.conv1(out)
out = self.conv2(self.act2(out))
if self.attn is not None:
out = self.attn_gain * self.attn(out)
out = self.conv3(self.act3(out))
out = self.drop_path(out)
if self.skipinit_gain is None:
out = out * self.alpha + shortcut
else:
# this really slows things down for some reason, TBD
out = out * self.alpha * self.skipinit_gain + shortcut
return out
def create_stem(in_chs, out_chs, stem_type='', conv_layer=None):
stem = OrderedDict()
assert stem_type in ('', 'deep', '3x3', '7x7')
if 'deep' in stem_type:
# 3 deep 3x3 conv stack as in ResNet V1D models
mid_chs = out_chs // 2
stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2)
stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1)
stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1)
elif '3x3' in stem_type:
# 3x3 stem conv as in RegNet
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2)
else:
# 7x7 stem conv as in ResNet
stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2)
return nn.Sequential(stem)
_nonlin_gamma = dict(
silu=.5595,
relu=(0.5 * (1. - 1. / math.pi)) ** 0.5,
identity=1.0
)
class NormalizerFreeNet(nn.Module):
""" Normalizer-free ResNets and RegNets
As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
- https://arxiv.org/abs/2101.08692
This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and
the (preact) ResNet models described earlier in the paper.
There are a few differences:
* channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
* activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
* skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
for what it is/does. Approx 8-10% throughput loss.
"""
def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
drop_rate=0., drop_path_rate=0.):
super().__init__()
self.num_classes = num_classes
self.drop_rate = drop_rate
act_layer = get_act_layer(cfg.act_layer)
assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
self.feature_info = [] # FIXME fill out feature info
stem_chs = cfg.stem_chs or cfg.channels[0]
stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
prev_chs = stem_chs
dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
net_stride = 2
dilation = 1
expected_var = 1.0
stages = []
for stage_idx, stage_depth in enumerate(cfg.depths):
if net_stride >= output_stride:
dilation *= 2
stride = 1
else:
stride = 2
net_stride *= stride
first_dilation = 1 if dilation in (1, 2) else 2
blocks = []
for block_idx in range(cfg.depths[stage_idx]):
first_block = block_idx == 0 and stage_idx == 0
out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
blocks += [NormalizationFreeBlock(
in_chs=prev_chs, out_chs=out_chs,
alpha=cfg.alpha,
beta=1. / expected_var ** 0.5, # NOTE: beta used as multiplier in block
stride=stride if block_idx == 0 else 1,
dilation=dilation,
first_dilation=first_dilation,
group_size=cfg.group_size,
bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
efficient=cfg.efficient,
ch_div=cfg.ch_div,
attn_layer=attn_layer,
attn_gain=cfg.attn_gain,
act_layer=act_layer,
conv_layer=conv_layer,
drop_path_rate=dpr[stage_idx][block_idx],
skipinit=cfg.skipinit,
)]
if block_idx == 0:
expected_var = 1. # expected var is reset after first block of each stage
expected_var += cfg.alpha ** 2 # Even if reset occurs, increment expected variance
first_dilation = dilation
prev_chs = out_chs
stages += [nn.Sequential(*blocks)]
self.stages = nn.Sequential(*stages)
if cfg.efficient and cfg.num_features:
# The paper NFRegNet models have an EfficientNet-like final head convolution.
self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
self.final_conv = conv_layer(prev_chs, self.num_features, 1)
else:
self.num_features = prev_chs
self.final_conv = nn.Identity()
self.final_act = act_layer()
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
for n, m in self.named_modules():
if 'fc' in n and isinstance(m, nn.Linear):
nn.init.zeros_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Conv2d):
# as per discussion with paper authors, original in haiku is
# hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal')' w/ zero'd bias
nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
if m.bias is not None:
nn.init.zeros_(m.bias)
def get_classifier(self):
return self.head.fc
def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def forward_features(self, x):
x = self.stem(x)
x = self.stages(x)
x = self.final_conv(x)
x = self.final_act(x)
return x
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
def _create_normfreenet(variant, pretrained=False, **kwargs):
feature_cfg = dict(flatten_sequential=True)
feature_cfg['feature_cls'] = 'hook' # pre-act models need hooks to grab feat from act1 in bottleneck blocks
return build_model_with_cfg(
NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
feature_cfg=feature_cfg, **kwargs)
@register_model
def nf_regnet_b0(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b1(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b2(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b3(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b4(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
@register_model
def nf_regnet_b5(pretrained=False, **kwargs):
return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
@register_model
def nf_resnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
@register_model
def nf_seresnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet26d(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
@register_model
def nf_ecaresnet50d(pretrained=False, **kwargs):
return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)

@ -32,13 +32,12 @@ from collections import OrderedDict # pylint: disable=g-importing-member
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
from .helpers import build_model_with_cfg
from .registry import register_model
from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
def _cfg(url='', **kwargs):
@ -112,43 +111,6 @@ def make_div(v, divisor=8):
return new_v
class StdConv2d(nn.Conv2d):
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=bias, groups=groups)
self.eps = eps
def forward(self, x):
w = self.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / (torch.sqrt(v) + self.eps)
x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
class StdConv2dSame(nn.Conv2d):
"""StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
"""
def __init__(
self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
padding = get_padding(kernel_size, stride, dilation)
super().__init__(
in_channel, out_channels, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=bias, groups=groups)
self.eps = eps
def forward(self, x):
w = self.weight
v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
w = (w - m) / (torch.sqrt(v) + self.eps)
x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
return x
def tf2th(conv_weights):
"""Possibly convert HWIO to OIHW."""
if conv_weights.ndim == 4:

@ -15,7 +15,7 @@ from math import ceil
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import build_model_with_cfg
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
from .registry import register_model
from .efficientnet_builder import efficientnet_init_weights
@ -49,12 +49,6 @@ default_cfgs = dict(
)
def make_divisible(v, divisor=8, min_value=None):
min_value = min_value or divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
return new_v
class SEWithNorm(nn.Module):
def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,

@ -28,9 +28,9 @@ import torch.nn.functional as F
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from .helpers import load_pretrained
from .layers import DropPath, to_2tuple, trunc_normal_
from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
from .resnet import resnet26d, resnet50d
from .resnetv2 import ResNetV2, StdConv2dSame
from .resnetv2 import ResNetV2
from .registry import register_model
_logger = logging.getLogger(__name__)

Loading…
Cancel
Save