diff --git a/timm/models/__init__.py b/timm/models/__init__.py
index fba6d1b8..a7e85084 100644
--- a/timm/models/__init__.py
+++ b/timm/models/__init__.py
@@ -11,6 +11,7 @@ from .inception_v3 import *
 from .inception_v4 import *
 from .mobilenetv3 import *
 from .nasnet import *
+from .nfnet import *
 from .pnasnet import *
 from .regnet import *
 from .res2net import *
diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py
index 142377a9..8f52099f 100644
--- a/timm/models/layers/__init__.py
+++ b/timm/models/layers/__init__.py
@@ -10,13 +10,13 @@ from .config import is_exportable, is_scriptable, is_no_jit, set_exportable, set
 from .conv2d_same import Conv2dSame, conv2d_same
 from .conv_bn_act import ConvBnAct
 from .create_act import create_act_layer, get_act_layer, get_act_fn
-from .create_attn import create_attn
+from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
 from .create_norm_act import create_norm_act, get_norm_act_layer
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
-from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple
+from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible
 from .inplace_abn import InplaceAbn
 from .linear import Linear
 from .mixed_conv2d import MixedConv2d
@@ -29,5 +29,6 @@ from .separable_conv import SeparableConv2d, SeparableConvBnAct
 from .space_to_depth import SpaceToDepthModule
 from .split_attn import SplitAttnConv2d
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
+from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d
 from .test_time_pool import TestTimePoolHead, apply_test_time_pool
 from .weight_init import trunc_normal_
diff --git a/timm/models/layers/create_attn.py b/timm/models/layers/create_attn.py
index 59ecd858..f4a4c2c9 100644
--- a/timm/models/layers/create_attn.py
+++ b/timm/models/layers/create_attn.py
@@ -8,7 +8,7 @@ from .eca import EcaModule, CecaModule
 from .cbam import CbamModule, LightCbamModule
 
 
-def create_attn(attn_type, channels, **kwargs):
+def get_attn(attn_type):
     module_cls = None
     if attn_type is not None:
         if isinstance(attn_type, str):
@@ -32,6 +32,12 @@ def create_attn(attn_type, channels, **kwargs):
                 module_cls = SEModule
         else:
             module_cls = attn_type
+    return module_cls
+
+
+def create_attn(attn_type, channels, **kwargs):
+    module_cls = get_attn(attn_type)
     if module_cls is not None:
+        # NOTE: it's expected the first (positional) argument of all attention layers is the
+        # input channels
         return module_cls(channels, **kwargs)
     return None
diff --git a/timm/models/layers/helpers.py b/timm/models/layers/helpers.py
index 8d7b559b..371b3cf6 100644
--- a/timm/models/layers/helpers.py
+++ b/timm/models/layers/helpers.py
@@ -22,6 +22,10 @@ to_4tuple = _ntuple(4)
 to_ntuple = _ntuple
 
 
-
-
-
+def make_divisible(v, divisor=8, min_value=None):
+    min_value = min_value or divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that rounding down does not go down by more than 10%.
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
diff --git a/timm/models/layers/se.py b/timm/models/layers/se.py
index a896fb71..54c0ef33 100644
--- a/timm/models/layers/se.py
+++ b/timm/models/layers/se.py
@@ -1,13 +1,27 @@
 from torch import nn as nn
+import torch.nn.functional as F
+
 from .create_act import create_act_layer
+from .helpers import make_divisible
 
 
 class SEModule(nn.Module):
-
-    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_layer='sigmoid'):
+    """ SE Module as defined in original SE-Nets with a few additions
+    Additions include:
+        * min_channels can be specified to keep reduced channel count at a minimum (default: 8)
+        * divisor can be specified to keep channels rounded to specified values (default: 1)
+        * reduction channels can be specified directly by arg (if reduction_channels is set)
+        * reduction channels can be specified by float ratio (if reduction_ratio is set)
+    """
+    def __init__(self, channels, reduction=16, act_layer=nn.ReLU, gate_layer='sigmoid',
+                 reduction_ratio=None, reduction_channels=None, min_channels=8, divisor=1):
         super(SEModule, self).__init__()
-        reduction_channels = reduction_channels or max(channels // reduction, min_channels)
+        if reduction_channels is not None:
+            reduction_channels = reduction_channels  # direct specification highest priority, no rounding/min done
+        elif reduction_ratio is not None:
+            reduction_channels = make_divisible(channels * reduction_ratio, divisor, min_channels)
+        else:
+            reduction_channels = make_divisible(channels // reduction, divisor, min_channels)
         self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
diff --git a/timm/models/layers/std_conv.py b/timm/models/layers/std_conv.py
new file mode 100644
index 00000000..d7f8274e
--- /dev/null
+++ b/timm/models/layers/std_conv.py
@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+from .padding import get_padding
+from .conv2d_same import conv2d_same
+
+
+def get_weight(module):
+    std, mean = torch.std_mean(module.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+    weight = (module.weight - mean) / (std + module.eps)
+    return weight
+
+
+class StdConv2d(nn.Conv2d):
+    """Conv2d with Weight Standardization. Used for BiT ResNet-V2 models.
+
+    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+        https://arxiv.org/abs/1903.10520v2
+    """
+    def __init__(
+            self, in_channel, out_channels, kernel_size, stride=1,
+            padding=None, dilation=1, groups=1, bias=False, eps=1e-5):
+        if padding is None:
+            padding = get_padding(kernel_size, stride, dilation)
+        super().__init__(
+            in_channel, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (std + self.eps)
+        return weight
+
+    def forward(self, x):
+        x = F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
+class StdConv2dSame(nn.Conv2d):
+    """Conv2d with Weight Standardization. TF compatible SAME padding. Used for ViT Hybrid model.
+
+    Paper: `Micro-Batch Training with Batch-Channel Normalization and Weight Standardization` -
+        https://arxiv.org/abs/1903.10520v2
+    """
+    def __init__(
+            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=False, eps=1e-5):
+        super().__init__(
+            in_channel, out_channels, kernel_size, stride=stride,
+            padding=0, dilation=dilation, groups=groups, bias=bias)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (std + self.eps)
+        return weight
+
+    def forward(self, x):
+        x = conv2d_same(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
+        return x
+
+
+class ScaledStdConv2d(nn.Conv2d):
+    """Conv2d layer with Scaled Weight Standardization.
+
+    Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets` -
+        https://arxiv.org/abs/2101.08692
+    """
+
+    def __init__(self, in_channels, out_channels, kernel_size,
+                 stride=1, padding=None, dilation=1, groups=1, bias=True, gain=True, gamma=1.0, eps=1e-5):
+        if padding is None:
+            padding = get_padding(kernel_size, stride, dilation)
+        super().__init__(
+            in_channels, out_channels, kernel_size, stride=stride,
+            padding=padding, dilation=dilation, groups=groups, bias=bias)
+        self.gain = nn.Parameter(torch.ones(self.out_channels, 1, 1, 1)) if gain else None
+        self.gamma = gamma * self.weight[0].numel() ** 0.5  # gamma * sqrt(fan-in)
+        self.eps = eps
+
+    def get_weight(self):
+        std, mean = torch.std_mean(self.weight, dim=[1, 2, 3], keepdim=True, unbiased=False)
+        weight = (self.weight - mean) / (self.gamma * std + self.eps)
+        if self.gain is not None:
+            weight = weight * self.gain
+        return weight
+
+    def forward(self, x):
+        return F.conv2d(x, self.get_weight(), self.bias, self.stride, self.padding, self.dilation, self.groups)
diff --git a/timm/models/nfnet.py b/timm/models/nfnet.py
new file mode 100644
index 00000000..69ca9fee
--- /dev/null
+++ b/timm/models/nfnet.py
@@ -0,0 +1,441 @@
+""" Normalizer Free RegNet / ResNet (pre-activation) Models
+
+Paper: `Characterizing signal propagation to close the performance gap in unnormalized ResNets`
+    - https://arxiv.org/abs/2101.08692
+
+Hacked together by / copyright Ross Wightman, 2021.
+""" +import math +from dataclasses import dataclass, field +from collections import OrderedDict +from typing import Tuple, Optional +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from .helpers import build_model_with_cfg +from .registry import register_model +from .layers import ClassifierHead, DropPath, AvgPool2dSame, ScaledStdConv2d, get_act_layer, get_attn, make_divisible + + +def _dcfg(url='', **kwargs): + return { + 'url': url, + 'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7), + 'crop_pct': 0.875, 'interpolation': 'bicubic', + 'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, + 'first_conv': 'stem.conv', 'classifier': 'head.fc', + **kwargs + } + +# FIXME finish +default_cfgs = { + 'nf_regnet_b0': _dcfg(url=''), + 'nf_regnet_b1': _dcfg(url='', input_size=(3, 240, 240)), + 'nf_regnet_b2': _dcfg(url='', input_size=(3, 256, 256)), + 'nf_regnet_b3': _dcfg(url='', input_size=(3, 272, 272)), + 'nf_regnet_b4': _dcfg(url='', input_size=(3, 320, 320)), + 'nf_regnet_b5': _dcfg(url='', input_size=(3, 384, 384)), + + 'nf_resnet26d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_resnet50d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_resnet101d': _dcfg(url='', first_conv='stem.conv1'), + + 'nf_seresnet26d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_seresnet50d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_seresnet101d': _dcfg(url='', first_conv='stem.conv1'), + + 'nf_ecaresnet26d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_ecaresnet50d': _dcfg(url='', first_conv='stem.conv1'), + 'nf_ecaresnet101d': _dcfg(url='', first_conv='stem.conv1'), +} + + +@dataclass +class NfCfg: + depths: Tuple[int, int, int, int] + channels: Tuple[int, int, int, int] + alpha: float = 0.2 + stem_type: str = '3x3' + stem_chs: Optional[int] = None + group_size: Optional[int] = 8 + attn_layer: Optional[str] = 'se' + attn_kwargs: dict = field(default_factory=lambda: dict(reduction_ratio=0.5, divisor=8)) + attn_gain: float = 2.0 # NF correction gain to apply if attn layer is used + width_factor: float = 0.75 + bottle_ratio: float = 2.25 + efficient: bool = True # enables EfficientNet-like options that are used in paper 'nf_regnet_b*' models + num_features: int = 1280 # num out_channels for final conv (when enabled in efficient mode) + ch_div: int = 8 # round channels % 8 == 0 to keep tensor-core use optimal + skipinit: bool = False + act_layer: str = 'silu' + + +model_cfgs = dict( + # EffNet influenced RegNet defs + nf_regnet_b0=NfCfg(depths=(1, 3, 6, 6), channels=(48, 104, 208, 440), num_features=1280), + nf_regnet_b1=NfCfg(depths=(2, 4, 7, 7), channels=(48, 104, 208, 440), num_features=1280), + nf_regnet_b2=NfCfg(depths=(2, 4, 8, 8), channels=(56, 112, 232, 488), num_features=1416), + nf_regnet_b3=NfCfg(depths=(2, 5, 9, 9), channels=(56, 128, 248, 528), num_features=1536), + nf_regnet_b4=NfCfg(depths=(2, 6, 11, 11), channels=(64, 144, 288, 616), num_features=1792), + nf_regnet_b5=NfCfg(depths=(3, 7, 14, 14), channels=(80, 168, 336, 704), num_features=2048), + + # ResNet (preact, D style deep stem/avg down) defs + nf_resnet26d=NfCfg( + depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048), + stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None, + act_layer='relu', attn_layer=None,), + nf_resnet50d=NfCfg( + depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048), + stem_type='deep', stem_chs=64, width_factor=1.0, 
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer=None),
+    nf_resnet101d=NfCfg(
+        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer=None),
+
+
+    nf_seresnet26d=NfCfg(
+        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet50d=NfCfg(
+        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+    nf_seresnet101d=NfCfg(
+        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='se', attn_kwargs=dict(reduction_ratio=0.25)),
+
+
+    nf_ecaresnet26d=NfCfg(
+        depths=(2, 2, 2, 2), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
+    nf_ecaresnet50d=NfCfg(
+        depths=(3, 4, 6, 3), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
+    nf_ecaresnet101d=NfCfg(
+        depths=(3, 4, 23, 3), channels=(256, 512, 1024, 2048),
+        stem_type='deep', stem_chs=64, width_factor=1.0, bottle_ratio=0.25, efficient=False, group_size=None,
+        act_layer='relu', attn_layer='eca', attn_kwargs=dict()),
+
+)
+
+# class NormFreeSiLU(nn.Module):
+#     _K = 1. / 0.5595
+#     def __init__(self, inplace=False):
+#         super().__init__()
+#         self.inplace = inplace
+#
+#     def forward(self, x):
+#         return F.silu(x, inplace=self.inplace) * self._K
+#
+#
+# class NormFreeReLU(nn.Module):
+#     _K = (0.5 * (1. - 1. / math.pi)) ** -0.5
+#
+#     def __init__(self, inplace=False):
+#         super().__init__()
+#         self.inplace = inplace
+#
+#     def forward(self, x):
+#         return F.relu(x, inplace=self.inplace) * self._K
+
+
+class DownsampleAvg(nn.Module):
+    def __init__(
+            self, in_chs, out_chs, stride=1, dilation=1, first_dilation=None, conv_layer=ScaledStdConv2d):
+        """ AvgPool Downsampling as in 'D' ResNet variants. Support for dilation."""
+        super(DownsampleAvg, self).__init__()
+        avg_stride = stride if dilation == 1 else 1
+        if stride > 1 or dilation > 1:
+            avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+            self.pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+        else:
+            self.pool = nn.Identity()
+        self.conv = conv_layer(in_chs, out_chs, 1, stride=1)
+
+    def forward(self, x):
+        return self.conv(self.pool(x))
+
+
+class NormalizationFreeBlock(nn.Module):
+    """Normalization-free pre-activation block.
+ """ + + def __init__( + self, in_chs, out_chs=None, stride=1, dilation=1, first_dilation=None, + alpha=1.0, beta=1.0, bottle_ratio=0.25, efficient=True, ch_div=1, group_size=None, + attn_layer=None, attn_gain=2.0, act_layer=None, conv_layer=None, drop_path_rate=0., skipinit=False): + super().__init__() + first_dilation = first_dilation or dilation + out_chs = out_chs or in_chs + # EfficientNet-like models scale bottleneck from in_chs, otherwise scale from out_chs like ResNet + mid_chs = make_divisible(in_chs * bottle_ratio if efficient else out_chs * bottle_ratio, ch_div) + groups = 1 + if group_size is not None: + # NOTE: not correcting the mid_chs % group_size, fix model def if broken. I want % ch_div == 0 to stand. + groups = mid_chs // group_size + self.alpha = alpha + self.beta = beta + self.attn_gain = attn_gain + + if in_chs != out_chs or stride != 1 or dilation != first_dilation: + self.downsample = DownsampleAvg( + in_chs, out_chs, stride=stride, dilation=dilation, first_dilation=first_dilation, conv_layer=conv_layer) + else: + self.downsample = None + + self.act1 = act_layer() + self.conv1 = conv_layer(in_chs, mid_chs, 1) + self.act2 = act_layer(inplace=True) + self.conv2 = conv_layer(mid_chs, mid_chs, 3, stride=stride, dilation=first_dilation, groups=groups) + if attn_layer is not None: + self.attn = attn_layer(mid_chs) + else: + self.attn = None + self.act3 = act_layer() + self.conv3 = conv_layer(mid_chs, out_chs, 1) + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else nn.Identity() + self.skipinit_gain = nn.Parameter(torch.tensor(0.)) if skipinit else None + + def forward(self, x): + out = self.act1(x) * self.beta + + # shortcut branch + shortcut = x + if self.downsample is not None: + shortcut = self.downsample(out) + + # residual branch + out = self.conv1(out) + out = self.conv2(self.act2(out)) + if self.attn is not None: + out = self.attn_gain * self.attn(out) + out = self.conv3(self.act3(out)) + out = self.drop_path(out) + if self.skipinit_gain is None: + out = out * self.alpha + shortcut + else: + # this really slows things down for some reason, TBD + out = out * self.alpha * self.skipinit_gain + shortcut + return out + + +def create_stem(in_chs, out_chs, stem_type='', conv_layer=None): + stem = OrderedDict() + assert stem_type in ('', 'deep', '3x3', '7x7') + if 'deep' in stem_type: + # 3 deep 3x3 conv stack as in ResNet V1D models + mid_chs = out_chs // 2 + stem['conv1'] = conv_layer(in_chs, mid_chs, kernel_size=3, stride=2) + stem['conv2'] = conv_layer(mid_chs, mid_chs, kernel_size=3, stride=1) + stem['conv3'] = conv_layer(mid_chs, out_chs, kernel_size=3, stride=1) + elif '3x3' in stem_type: + # 3x3 stem conv as in RegNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=3, stride=2) + else: + # 7x7 stem conv as in ResNet + stem['conv'] = conv_layer(in_chs, out_chs, kernel_size=7, stride=2) + + return nn.Sequential(stem) + + +_nonlin_gamma = dict( + silu=.5595, + relu=(0.5 * (1. - 1. / math.pi)) ** 0.5, + identity=1.0 +) + + +class NormalizerFreeNet(nn.Module): + """ Normalizer-free ResNets and RegNets + + As described in `Characterizing signal propagation to close the performance gap in unnormalized ResNets` + - https://arxiv.org/abs/2101.08692 + + This model aims to cover both the NFRegNet-Bx models as detailed in the paper's code snippets and + the (preact) ResNet models described earlier in the paper. 
+
+    There are a few differences:
+        * channels are rounded to be divisible by 8 by default (keep TC happy), this changes param counts
+        * activation correcting gamma constants are moved into the ScaledStdConv as it has less performance
+            impact in PyTorch when done with the weight scaling there. This likely wasn't a concern in the JAX impl.
+        * skipinit is disabled by default, it seems to have a rather drastic impact on GPU memory use and throughput
+            for what it is/does. Approx 8-10% throughput loss.
+    """
+    def __init__(self, cfg: NfCfg, num_classes=1000, in_chans=3, global_pool='avg', output_stride=32,
+                 drop_rate=0., drop_path_rate=0.):
+        super().__init__()
+        self.num_classes = num_classes
+        self.drop_rate = drop_rate
+        act_layer = get_act_layer(cfg.act_layer)
+        assert cfg.act_layer in _nonlin_gamma, f"Please add non-linearity constants for activation ({cfg.act_layer})."
+        conv_layer = partial(ScaledStdConv2d, bias=True, gain=True, gamma=_nonlin_gamma[cfg.act_layer])
+        attn_layer = partial(get_attn(cfg.attn_layer), **cfg.attn_kwargs) if cfg.attn_layer else None
+
+        self.feature_info = []  # FIXME fill out feature info
+
+        stem_chs = cfg.stem_chs or cfg.channels[0]
+        stem_chs = make_divisible(stem_chs * cfg.width_factor, cfg.ch_div)
+        self.stem = create_stem(in_chans, stem_chs, cfg.stem_type, conv_layer=conv_layer)
+
+        prev_chs = stem_chs
+        dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(cfg.depths)).split(cfg.depths)]
+        net_stride = 2
+        dilation = 1
+        expected_var = 1.0
+        stages = []
+        for stage_idx, stage_depth in enumerate(cfg.depths):
+            if net_stride >= output_stride:
+                dilation *= 2
+                stride = 1
+            else:
+                stride = 2
+            net_stride *= stride
+            first_dilation = 1 if dilation in (1, 2) else 2
+
+            blocks = []
+            for block_idx in range(cfg.depths[stage_idx]):
+                first_block = block_idx == 0 and stage_idx == 0
+                out_chs = make_divisible(cfg.channels[stage_idx] * cfg.width_factor, cfg.ch_div)
+                blocks += [NormalizationFreeBlock(
+                    in_chs=prev_chs, out_chs=out_chs,
+                    alpha=cfg.alpha,
+                    beta=1. / expected_var ** 0.5,  # NOTE: beta used as multiplier in block
+                    stride=stride if block_idx == 0 else 1,
+                    dilation=dilation,
+                    first_dilation=first_dilation,
+                    group_size=cfg.group_size,
+                    bottle_ratio=1. if cfg.efficient and first_block else cfg.bottle_ratio,
+                    efficient=cfg.efficient,
+                    ch_div=cfg.ch_div,
+                    attn_layer=attn_layer,
+                    attn_gain=cfg.attn_gain,
+                    act_layer=act_layer,
+                    conv_layer=conv_layer,
+                    drop_path_rate=dpr[stage_idx][block_idx],
+                    skipinit=cfg.skipinit,
+                )]
+                if block_idx == 0:
+                    expected_var = 1.  # expected var is reset after first block of each stage
+                expected_var += cfg.alpha ** 2  # Even if reset occurs, increment expected variance
+                first_dilation = dilation
+                prev_chs = out_chs
+            stages += [nn.Sequential(*blocks)]
+        self.stages = nn.Sequential(*stages)
+
+        if cfg.efficient and cfg.num_features:
+            # The paper NFRegNet models have an EfficientNet-like final head convolution.
+            self.num_features = make_divisible(cfg.width_factor * cfg.num_features, cfg.ch_div)
+            self.final_conv = conv_layer(prev_chs, self.num_features, 1)
+        else:
+            self.num_features = prev_chs
+            self.final_conv = nn.Identity()
+        self.final_act = act_layer()
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+        for n, m in self.named_modules():
+            if 'fc' in n and isinstance(m, nn.Linear):
+                nn.init.zeros_(m.weight)
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Conv2d):
+                # as per discussion with paper authors, original in haiku is
+                # hk.initializers.VarianceScaling(1.0, 'fan_in', 'normal') w/ zero'd bias
+                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='linear')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes, global_pool='avg'):
+        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+
+    def forward_features(self, x):
+        x = self.stem(x)
+        x = self.stages(x)
+        x = self.final_conv(x)
+        x = self.final_act(x)
+        return x
+
+    def forward(self, x):
+        x = self.forward_features(x)
+        x = self.head(x)
+        return x
+
+
+def _create_normfreenet(variant, pretrained=False, **kwargs):
+    feature_cfg = dict(flatten_sequential=True)
+    feature_cfg['feature_cls'] = 'hook'  # pre-act models need hooks to grab feat from act1 in bottleneck blocks
+
+    return build_model_with_cfg(
+        NormalizerFreeNet, variant, pretrained, model_cfg=model_cfgs[variant], default_cfg=default_cfgs[variant],
+        feature_cfg=feature_cfg, **kwargs)
+
+
+@register_model
+def nf_regnet_b0(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b0', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b1(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b1', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b2(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b2', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b3(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b3', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b4(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b4', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_regnet_b5(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_regnet_b5', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_resnet26d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet26d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_resnet50d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_resnet50d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_seresnet26d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet26d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_seresnet50d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_seresnet50d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_ecaresnet26d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet26d', pretrained=pretrained, **kwargs)
+
+
+@register_model
+def nf_ecaresnet50d(pretrained=False, **kwargs):
+    return _create_normfreenet('nf_ecaresnet50d', pretrained=pretrained, **kwargs)
diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py
index 1acc5eb0..73c2e42c 100644
--- a/timm/models/resnetv2.py
+++ b/timm/models/resnetv2.py
@@ -32,13 +32,12 @@ from collections import OrderedDict  # pylint: disable=g-importing-member
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from functools import partial
 
 from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
 from .helpers import build_model_with_cfg
 from .registry import register_model
-from .layers import get_padding, GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, conv2d_same
+from .layers import GroupNormAct, ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d
 
 
 def _cfg(url='', **kwargs):
@@ -112,43 +111,6 @@ def make_div(v, divisor=8):
     return new_v
 
 
-class StdConv2d(nn.Conv2d):
-
-    def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
-        padding = get_padding(kernel_size, stride, dilation)
-        super().__init__(
-            in_channel, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, bias=bias, groups=groups)
-        self.eps = eps
-
-    def forward(self, x):
-        w = self.weight
-        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        w = (w - m) / (torch.sqrt(v) + self.eps)
-        x = F.conv2d(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
-        return x
-
-
-class StdConv2dSame(nn.Conv2d):
-    """StdConv2d w/ TF compatible SAME padding. Used for ViT Hybrid model.
-    """
-    def __init__(
-            self, in_channel, out_channels, kernel_size, stride=1, dilation=1, bias=False, groups=1, eps=1e-5):
-        padding = get_padding(kernel_size, stride, dilation)
-        super().__init__(
-            in_channel, out_channels, kernel_size, stride=stride,
-            padding=padding, dilation=dilation, bias=bias, groups=groups)
-        self.eps = eps
-
-    def forward(self, x):
-        w = self.weight
-        v, m = torch.var_mean(w, dim=[1, 2, 3], keepdim=True, unbiased=False)
-        w = (w - m) / (torch.sqrt(v) + self.eps)
-        x = conv2d_same(x, w, self.bias, self.stride, self.padding, self.dilation, self.groups)
-        return x
-
-
 def tf2th(conv_weights):
     """Possibly convert HWIO to OIHW."""
     if conv_weights.ndim == 4:
diff --git a/timm/models/rexnet.py b/timm/models/rexnet.py
index 6444b3c8..c4e9b366 100644
--- a/timm/models/rexnet.py
+++ b/timm/models/rexnet.py
@@ -15,7 +15,7 @@ from math import ceil
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import build_model_with_cfg
-from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath
+from .layers import ClassifierHead, create_act_layer, ConvBnAct, DropPath, make_divisible
 from .registry import register_model
 from .efficientnet_builder import efficientnet_init_weights
 
@@ -49,12 +49,6 @@ default_cfgs = dict(
 )
 
 
-def make_divisible(v, divisor=8, min_value=None):
-    min_value = min_value or divisor
-    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
-    return new_v
-
-
 class SEWithNorm(nn.Module):
 
     def __init__(self, channels, se_ratio=1 / 12., act_layer=nn.ReLU, divisor=1, reduction_channels=None,
diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ff2510f1..acd4d18d 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -28,9 +28,9 @@ import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from .helpers import load_pretrained
-from .layers import DropPath, to_2tuple, trunc_normal_
+from .layers import StdConv2dSame, DropPath, to_2tuple, trunc_normal_
 from .resnet import resnet26d, resnet50d
-from .resnetv2 import ResNetV2, StdConv2dSame
+from .resnetv2 import ResNetV2
 from .registry import register_model
 
 _logger = logging.getLogger(__name__)
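
The snippet below is an editorial sketch, not part of the patch: a minimal smoke test of the pieces added above (make_divisible, ScaledStdConv2d, and the registered nf_regnet_b0 entrypoint), assuming this diff is applied to a source checkout of timm. The default_cfgs URLs are still empty, so only pretrained=False is assumed to work.

# Editorial sketch (not part of the diff): quick checks against the additions above.
import torch
import timm
from timm.models.layers import ScaledStdConv2d, make_divisible

# make_divisible rounds a channel count to the nearest multiple of `divisor`,
# bumping back up if rounding down would lose more than 10% of the requested value.
assert make_divisible(46, divisor=8) == 48

# ScaledStdConv2d re-standardizes its weight on every forward pass; `gamma` folds the
# activation-specific correction into the weight scaling (0.5595 is the SiLU constant above).
conv = ScaledStdConv2d(16, 32, kernel_size=3, stride=1, gamma=0.5595)
print(conv(torch.randn(2, 16, 32, 32)).shape)  # torch.Size([2, 32, 32, 32])

# The @register_model functions expose the new variants through timm.create_model.
# No pretrained weights are published yet (default_cfgs urls are empty), so pretrained=False.
model = timm.create_model('nf_regnet_b0', pretrained=False)
model.eval()
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 1000])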