From 78912b6375ae857ef5e9cc99ac951a154c5a5a71 Mon Sep 17 00:00:00 2001 From: Ross Wightman Date: Wed, 1 Dec 2021 12:07:45 -0800 Subject: [PATCH] Updated EvoNorm implementations with some experimentation. Add FilterResponseNorm. Updated RegnetZ and ResNetV2 model defs for trials. --- timm/models/byobnet.py | 37 ++- timm/models/fx_features.py | 7 +- timm/models/layers/__init__.py | 4 +- timm/models/layers/create_act.py | 9 +- timm/models/layers/create_norm_act.py | 50 ++-- timm/models/layers/evo_norm.py | 323 ++++++++++++++++++--- timm/models/layers/filter_response_norm.py | 68 +++++ timm/models/resnetv2.py | 33 ++- timm/models/vovnet.py | 2 +- 9 files changed, 456 insertions(+), 77 deletions(-) create mode 100644 timm/models/layers/filter_response_norm.py diff --git a/timm/models/byobnet.py b/timm/models/byobnet.py index fa57943a..44f26e4e 100644 --- a/timm/models/byobnet.py +++ b/timm/models/byobnet.py @@ -35,7 +35,8 @@ import torch.nn as nn from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from .helpers import build_model_with_cfg, named_apply from .layers import ClassifierHead, ConvBnAct, BatchNormAct2d, DropPath, AvgPool2dSame, \ - create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNormSample2d + create_conv2d, get_act_layer, convert_norm_act, get_attn, make_divisible, to_2tuple, EvoNorm2dS0, EvoNorm2dS0a,\ + EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, FilterResponseNormAct2d, FilterResponseNormTlu2d from .registry import register_model __all__ = ['ByobNet', 'ByoModelCfg', 'ByoBlockCfg', 'create_byob_stem', 'create_block'] @@ -152,6 +153,12 @@ default_cfgs = { 'regnetz_e8': _cfgr( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-attn-weights/regnetz_e8_bh-aace8e6e.pth', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=1.0), + + 'regnetz_b16_evos': _cfgr( + url='', + mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), + input_size=(3, 224, 224), pool_size=(7, 7), test_input_size=(3, 288, 288), first_conv='stem.conv', + crop_pct=0.94), 'regnetz_d8_evob': _cfgr( url='', mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), test_input_size=(3, 320, 320), crop_pct=0.95), @@ -597,6 +604,23 @@ model_cfgs = dict( ), # experimental EvoNorm configs + regnetz_b16_evos=ByoModelCfg( + blocks=( + ByoBlockCfg(type='bottle', d=2, c=48, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=6, c=96, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=12, c=192, s=2, gs=16, br=3), + ByoBlockCfg(type='bottle', d=2, c=288, s=2, gs=16, br=3), + ), + stem_chs=32, + stem_pool='', + downsample='', + num_features=1536, + act_layer='silu', + norm_layer=partial(EvoNorm2dS0a, group_size=16), + attn_layer='se', + attn_kwargs=dict(rd_ratio=0.25), + block_kwargs=dict(bottle_in=True, linear_out=True), + ), regnetz_d8_evob=ByoModelCfg( blocks=( ByoBlockCfg(type='bottle', d=3, c=64, s=1, gs=8, br=4), @@ -610,7 +634,7 @@ model_cfgs = dict( downsample='', num_features=1792, act_layer='silu', - norm_layer='evonormbatch', + norm_layer='evonormb0', attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), @@ -628,7 +652,7 @@ model_cfgs = dict( downsample='', num_features=1792, act_layer='silu', - norm_layer=partial(EvoNormSample2d, groups=32), + norm_layer=partial(EvoNorm2dS0a, group_size=16), attn_layer='se', attn_kwargs=dict(rd_ratio=0.25), block_kwargs=dict(bottle_in=True, linear_out=True), @@ -856,6 +880,13 @@ def regnetz_e8(pretrained=False, **kwargs): return _create_byobnet('regnetz_e8', pretrained=pretrained, **kwargs) +@register_model +def regnetz_b16_evos(pretrained=False, **kwargs): + """ + """ + return _create_byobnet('regnetz_b16_evos', pretrained=pretrained, **kwargs) + + @register_model def regnetz_d8_evob(pretrained=False, **kwargs): """ diff --git a/timm/models/fx_features.py b/timm/models/fx_features.py index 5a25ee3e..f709d92e 100644 --- a/timm/models/fx_features.py +++ b/timm/models/fx_features.py @@ -14,6 +14,8 @@ except ImportError: # Layers we went to treat as leaf modules from .layers import Conv2dSame, ScaledStdConv2dSame, BatchNormAct2d, BlurPool2d, CondConv2d, StdConv2dSame, DropPath +from .layers import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2 +from .layers import EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a from .layers.non_local_attn import BilinearAttnTransform from .layers.pool2d_same import MaxPool2dSame, AvgPool2dSame @@ -24,9 +26,12 @@ _leaf_modules = { BilinearAttnTransform, # reason: flow control t <= 1 BlurPool2d, # reason: TypeError: F.conv2d received Proxy in groups=x.shape[1] # Reason: get_same_padding has a max which raises a control flow error - Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) DropPath, # reason: TypeError: rand recieved Proxy in `size` argument + EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2, # to(dtype) use that causes tracing failure (on scripted models only?) + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a, + } try: diff --git a/timm/models/layers/__init__.py b/timm/models/layers/__init__.py index 4831af9a..0ed0c3af 100644 --- a/timm/models/layers/__init__.py +++ b/timm/models/layers/__init__.py @@ -14,7 +14,9 @@ from .create_conv2d import create_conv2d from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path from .eca import EcaModule, CecaModule, EfficientChannelAttn, CircularEfficientChannelAttn -from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .evo_norm import EvoNorm2dB0, EvoNorm2dB1, EvoNorm2dB2,\ + EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a +from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d from .gather_excite import GatherExcite from .global_context import GlobalContext from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible diff --git a/timm/models/layers/create_act.py b/timm/models/layers/create_act.py index aa557692..e38f2e03 100644 --- a/timm/models/layers/create_act.py +++ b/timm/models/layers/create_act.py @@ -116,9 +116,6 @@ def get_act_fn(name: Union[Callable, str] = 'relu'): # custom autograd, then fallback if name in _ACT_FN_ME: return _ACT_FN_ME[name] - if is_exportable() and name in ('silu', 'swish'): - # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack - return swish if not (is_no_jit() or is_exportable()): if name in _ACT_FN_JIT: return _ACT_FN_JIT[name] @@ -132,14 +129,12 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'): """ if not name: return None - if isinstance(name, type): + if not isinstance(name, str): + # callable, module, etc return name if not (is_no_jit() or is_exportable() or is_scriptable()): if name in _ACT_LAYER_ME: return _ACT_LAYER_ME[name] - if is_exportable() and name in ('silu', 'swish'): - # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack - return Swish if not (is_no_jit() or is_exportable()): if name in _ACT_LAYER_JIT: return _ACT_LAYER_JIT[name] diff --git a/timm/models/layers/create_norm_act.py b/timm/models/layers/create_norm_act.py index 5b562945..5d4894a0 100644 --- a/timm/models/layers/create_norm_act.py +++ b/timm/models/layers/create_norm_act.py @@ -9,36 +9,42 @@ Hacked together by / Copyright 2020 Ross Wightman import types import functools -import torch -import torch.nn as nn - -from .evo_norm import EvoNormBatch2d, EvoNormSample2d +from .evo_norm import * +from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d from .norm_act import BatchNormAct2d, GroupNormAct from .inplace_abn import InplaceAbn -_NORM_ACT_TYPES = {BatchNormAct2d, GroupNormAct, EvoNormBatch2d, EvoNormSample2d, InplaceAbn} -_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, InplaceAbn} # requires act_layer arg to define act type +_NORM_ACT_MAP = dict( + batchnorm=BatchNormAct2d, + groupnorm=GroupNormAct, + evonormb0=EvoNorm2dB0, + evonormb1=EvoNorm2dB1, + evonormb2=EvoNorm2dB2, + evonorms0=EvoNorm2dS0, + evonorms0a=EvoNorm2dS0a, + evonorms1=EvoNorm2dS1, + evonorms1a=EvoNorm2dS1a, + evonorms2=EvoNorm2dS2, + evonorms2a=EvoNorm2dS2a, + frn=FilterResponseNormAct2d, + frntlu=FilterResponseNormTlu2d, + inplaceabn=InplaceAbn, + iabn=InplaceAbn, +) +_NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()} +# has act_layer arg to define act type +_NORM_ACT_REQUIRES_ARG = {BatchNormAct2d, GroupNormAct, FilterResponseNormAct2d, InplaceAbn} -def get_norm_act_layer(layer_class): - layer_class = layer_class.replace('_', '').lower() - if layer_class.startswith("batchnorm"): - layer = BatchNormAct2d - elif layer_class.startswith("groupnorm"): - layer = GroupNormAct - elif layer_class == "evonormbatch": - layer = EvoNormBatch2d - elif layer_class == "evonormsample": - layer = EvoNormSample2d - elif layer_class == "iabn" or layer_class == "inplaceabn": - layer = InplaceAbn - else: - assert False, "Invalid norm_act layer (%s)" % layer_class +def get_norm_act_layer(layer_name): + layer_name = layer_name.replace('_', '').lower().split('-')[0] + layer = _NORM_ACT_MAP.get(layer_name, None) + assert layer is not None, "Invalid norm_act layer (%s)" % layer_name return layer -def create_norm_act(layer_type, num_features, apply_act=True, jit=False, **kwargs): - layer_parts = layer_type.split('-') # e.g. batchnorm-leaky_relu +def create_norm_act(layer_name, num_features, apply_act=True, jit=False, **kwargs): + layer_parts = layer_name.split('-') # e.g. batchnorm-leaky_relu assert len(layer_parts) in (1, 2) layer = get_norm_act_layer(layer_parts[0]) #activation_class = layer_parts[1].lower() if len(layer_parts) > 1 else '' # FIXME support string act selection? diff --git a/timm/models/layers/evo_norm.py b/timm/models/layers/evo_norm.py index 6ef0c881..d42c502c 100644 --- a/timm/models/layers/evo_norm.py +++ b/timm/models/layers/evo_norm.py @@ -1,81 +1,332 @@ -"""EvoNormB0 (Batched) and EvoNormS0 (Sample) in PyTorch +""" EvoNorm in PyTorch + +Based on `Evolving Normalization-Activation Layers` - https://arxiv.org/abs/2004.02967 +@inproceedings{NEURIPS2020, + author = {Liu, Hanxiao and Brock, Andy and Simonyan, Karen and Le, Quoc}, + booktitle = {Advances in Neural Information Processing Systems}, + editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. Lin}, + pages = {13539--13550}, + publisher = {Curran Associates, Inc.}, + title = {Evolving Normalization-Activation Layers}, + url = {https://proceedings.neurips.cc/paper/2020/file/9d4c03631b8b0c85ae08bf05eda37d0f-Paper.pdf}, + volume = {33}, + year = {2020} +} An attempt at getting decent performing EvoNorms running in PyTorch. -While currently faster than other impl, still quite a ways off the built-in BN -in terms of memory usage and throughput (roughly 5x mem, 1/2 - 1/3x speed). +While faster than other PyTorch impl, still quite a ways off the built-in BatchNorm +in terms of memory usage and throughput on GPUs. -Still very much a WIP, fiddling with buffer usage, in-place/jit optimizations, and layouts. +I'm testing these modules on TPU w/ PyTorch XLA. Promising start but +currently working around some issues with builtin torch/tensor.var/std. Unlike +GPU, similar train speeds for EvoNormS variants and BatchNorm. Hacked together by / Copyright 2020 Ross Wightman """ import torch import torch.nn as nn +import torch.nn.functional as F +from .create_act import create_act_layer from .trace_utils import _assert -class EvoNormBatch2d(nn.Module): - def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, drop_block=None): - super(EvoNormBatch2d, self).__init__() +def instance_std(x, eps: float = 1e-5): + rms = x.float().var(dim=(2, 3), unbiased=False, keepdim=True).add(eps).sqrt().to(x.dtype) + return rms.expand(x.shape) + + +def instance_rms(x, eps: float = 1e-5): + rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).sqrt().to(dtype=x.dtype) + return rms.expand(x.shape) + + +def group_std(x, groups: int = 32, eps: float = 1e-5, flatten: bool = False): + B, C, H, W = x.shape + x_dtype = x.dtype + _assert(C % groups == 0, '') + # x = x.reshape(B, groups, -1) # FIXME simpler shape causing TPU / XLA issues + # std = x.float().var(dim=2, unbiased=False, keepdim=True).add(eps).sqrt() + x = x.reshape(B, groups, C // groups, H, W) + std = x.float().var(dim=(2, 3, 4), unbiased=False, keepdim=True).add(eps).sqrt() + return std.expand(x.shape).reshape(B, C, H, W).to(x_dtype) + + +def group_std_tpu(x, groups: int = 32, eps: float = 1e-5, diff_sqm: bool = False): + # This is a workaround for some stability / odd behaviour of .var and .std + # running on PyTorch XLA w/ TPUs. These manual var impl are producing much better results + B, C, H, W = x.shape + _assert(C % groups == 0, '') + x_dtype = x.dtype + x = x.float().reshape(B, groups, C // groups, H, W) + xm = x.mean(dim=(2, 3, 4), keepdim=True) + if diff_sqm: + # difference of squared mean and mean squared, faster on TPU + var = (x.square().mean(dim=(2, 3, 4), keepdim=True) - xm.square()).clamp(0) + else: + var = (x - xm).square().mean(dim=(2, 3, 4), keepdim=True) + return var.add(eps).sqrt().expand(x.shape).reshape(B, C, H, W).to(x_dtype) +# group_std = group_std_tpu # temporary, for TPU / PT XLA + + +def group_rms(x, groups: int = 32, eps: float = 1e-5): + B, C, H, W = x.shape + _assert(C % groups == 0, '') + x_dtype = x.dtype + x = x.reshape(B, groups, C // groups, H, W) + sqm = x.square().mean(dim=(2, 3, 4), keepdim=True).add(eps).sqrt_().to(dtype=x_dtype) + return sqm.expand(x.shape).reshape(B, C, H, W) + + +class EvoNorm2dB0(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() self.apply_act = apply_act # apply activation (non-linearity) self.momentum = momentum self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) - self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) - if self.apply_act: + if self.v is not None: nn.init.ones_(self.v) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') - x_type = x.dtype + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) if self.v is not None: - running_var = self.running_var.view(1, -1, 1, 1) if self.training: - var = x.var(dim=(0, 2, 3), unbiased=False, keepdim=True) + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach() * self.momentum * (n / (n - 1))) + else: + var = self.running_var + left = var.add(self.eps).sqrt_().to(x_dtype).view(v_shape).expand_as(x) + v = self.v.to(x_dtype).view(v_shape) + right = x * v + instance_std(x, self.eps) + x = x / left.max(right) + return x * self.weight.to(x_dtype).view(v_shape) + self.bias.to(x_dtype).view(v_shape) + + +class EvoNorm2dB1(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.momentum = momentum + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) n = x.numel() / x.shape[1] - running_var = var.detach() * self.momentum * (n / (n - 1)) + running_var * (1 - self.momentum) - self.running_var.copy_(running_var.view(self.running_var.shape)) + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) else: - var = running_var - v = self.v.to(dtype=x_type).reshape(1, -1, 1, 1) - d = x * v + (x.var(dim=(2, 3), unbiased=False, keepdim=True) + self.eps).sqrt().to(dtype=x_type) - d = d.max((var + self.eps).sqrt().to(dtype=x_type)) - x = x / d - return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + var = self.running_var + var = var.to(dtype=x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = (x + 1) * instance_rms(x, self.eps) + x = x / left.max(right) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) -class EvoNormSample2d(nn.Module): - def __init__(self, num_features, apply_act=True, groups=32, eps=1e-5, drop_block=None): - super(EvoNormSample2d, self).__init__() +class EvoNorm2dB2(nn.Module): + def __init__(self, num_features, apply_act=True, momentum=0.1, eps=1e-5, **_): + super().__init__() self.apply_act = apply_act # apply activation (non-linearity) - self.groups = groups + self.momentum = momentum self.eps = eps - self.weight = nn.Parameter(torch.ones(num_features), requires_grad=True) - self.bias = nn.Parameter(torch.zeros(num_features), requires_grad=True) - self.v = nn.Parameter(torch.ones(num_features), requires_grad=True) if apply_act else None + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.register_buffer('running_var', torch.ones(num_features)) self.reset_parameters() def reset_parameters(self): nn.init.ones_(self.weight) nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) if self.apply_act: + if self.training: + var = x.float().var(dim=(0, 2, 3), unbiased=False) + n = x.numel() / x.shape[1] + self.running_var.copy_( + self.running_var * (1 - self.momentum) + + var.detach().to(dtype=self.running_var.dtype) * self.momentum * (n / (n - 1))) + else: + var = self.running_var + var = var.to(dtype=x_dtype).view(v_shape) + left = var.add(self.eps).sqrt_() + right = instance_rms(x, self.eps) - x + x = x / left.max(right) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS0(nn.Module): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.v = nn.Parameter(torch.ones(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.v is not None: nn.init.ones_(self.v) def forward(self, x): _assert(x.dim() == 4, 'expected 4D input') - B, C, H, W = x.shape - _assert(C % self.groups == 0, '') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.v is not None: + v = self.v.view(v_shape).to(dtype=x_dtype) + x = x * (x * v).sigmoid() / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS0a(EvoNorm2dS0): + def __init__(self, num_features, groups=32, group_size=None, apply_act=True, eps=1e-5, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + d = group_std(x, self.groups, self.eps) if self.v is not None: - n = x * (x * self.v.view(1, -1, 1, 1)).sigmoid() - x = x.reshape(B, self.groups, -1) - x = n.reshape(B, self.groups, -1) / (x.var(dim=-1, unbiased=False, keepdim=True) + self.eps).sqrt() - x = x.reshape(B, C, H, W) - return x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1) + v = self.v.view(v_shape).to(dtype=x_dtype) + x = x * (x * v).sigmoid_() + x = x / d + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS1(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.pre_act_norm = False + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS1a(EvoNorm2dS1): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_std(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS2(nn.Module): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__() + self.apply_act = apply_act # apply activation (non-linearity) + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer) + else: + self.act = nn.Identity() + if group_size: + assert num_features % group_size == 0 + self.groups = num_features // group_size + else: + self.groups = groups + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + if self.apply_act: + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + + +class EvoNorm2dS2a(EvoNorm2dS2): + def __init__( + self, num_features, groups=32, group_size=None, + apply_act=True, act_layer=nn.SiLU, eps=1e-5, **_): + super().__init__( + num_features, groups=groups, group_size=group_size, apply_act=apply_act, act_layer=act_layer, eps=eps) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = self.act(x) / group_rms(x, self.groups, self.eps) + return x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) diff --git a/timm/models/layers/filter_response_norm.py b/timm/models/layers/filter_response_norm.py new file mode 100644 index 00000000..a66a1cd4 --- /dev/null +++ b/timm/models/layers/filter_response_norm.py @@ -0,0 +1,68 @@ +""" Filter Response Norm in PyTorch + +Based on `Filter Response Normalization Layer` - https://arxiv.org/abs/1911.09737 + +Hacked together by / Copyright 2021 Ross Wightman +""" +import torch +import torch.nn as nn + +from .create_act import create_act_layer +from .trace_utils import _assert + + +def inv_instance_rms(x, eps: float = 1e-5): + rms = x.square().float().mean(dim=(2, 3), keepdim=True).add(eps).rsqrt().to(x.dtype) + return rms.expand(x.shape) + + +class FilterResponseNormTlu2d(nn.Module): + def __init__(self, num_features, apply_act=True, eps=1e-5, rms=True, **_): + super(FilterResponseNormTlu2d, self).__init__() + self.apply_act = apply_act # apply activation (non-linearity) + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.tau = nn.Parameter(torch.zeros(num_features)) if apply_act else None + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + if self.tau is not None: + nn.init.zeros_(self.tau) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return torch.maximum(x, self.tau.reshape(v_shape).to(dtype=x_dtype)) if self.tau is not None else x + + +class FilterResponseNormAct2d(nn.Module): + def __init__(self, num_features, apply_act=True, act_layer=nn.ReLU, inplace=None, rms=True, eps=1e-5, **_): + super(FilterResponseNormAct2d, self).__init__() + if act_layer is not None and apply_act: + self.act = create_act_layer(act_layer, inplace=inplace) + else: + self.act = nn.Identity() + self.rms = rms + self.eps = eps + self.weight = nn.Parameter(torch.ones(num_features)) + self.bias = nn.Parameter(torch.zeros(num_features)) + self.reset_parameters() + + def reset_parameters(self): + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, x): + _assert(x.dim() == 4, 'expected 4D input') + x_dtype = x.dtype + v_shape = (1, -1, 1, 1) + x = x * inv_instance_rms(x, self.eps) + x = x * self.weight.view(v_shape).to(dtype=x_dtype) + self.bias.view(v_shape).to(dtype=x_dtype) + return self.act(x) diff --git a/timm/models/resnetv2.py b/timm/models/resnetv2.py index e38eaf5e..2c6fb9a0 100644 --- a/timm/models/resnetv2.py +++ b/timm/models/resnetv2.py @@ -38,7 +38,8 @@ from functools import partial from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD from .helpers import build_model_with_cfg, named_apply, adapt_input_conv from .registry import register_model -from .layers import GroupNormAct, BatchNormAct2d, EvoNormBatch2d, EvoNormSample2d,\ +from .layers import GroupNormAct, BatchNormAct2d, EvoNorm2dB0, EvoNorm2dS0,\ + EvoNorm2dS1, EvoNorm2dS2, FilterResponseNormTlu2d, FilterResponseNormAct2d,\ ClassifierHead, DropPath, AvgPool2dSame, create_pool2d, StdConv2d, create_conv2d @@ -125,7 +126,11 @@ default_cfgs = { interpolation='bicubic', first_conv='stem.conv1'), 'resnetv2_50d_evob': _cfg( interpolation='bicubic', first_conv='stem.conv1'), - 'resnetv2_50d_evos': _cfg( + 'resnetv2_50d_evos0': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_evos1': _cfg( + interpolation='bicubic', first_conv='stem.conv1'), + 'resnetv2_50d_frn': _cfg( interpolation='bicubic', first_conv='stem.conv1'), } @@ -660,13 +665,29 @@ def resnetv2_50d_gn(pretrained=False, **kwargs): def resnetv2_50d_evob(pretrained=False, **kwargs): return _create_resnetv2( 'resnetv2_50d_evob', pretrained=pretrained, - layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormBatch2d, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dB0, + stem_type='deep', avg_down=True, zero_init_last=True, **kwargs) + + +@register_model +def resnetv2_50d_evos0(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evos0', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNorm2dS0, + stem_type='deep', avg_down=True, **kwargs) + + +@register_model +def resnetv2_50d_evos1(pretrained=False, **kwargs): + return _create_resnetv2( + 'resnetv2_50d_evos1', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=partial(EvoNorm2dS1, group_size=16), stem_type='deep', avg_down=True, **kwargs) @register_model -def resnetv2_50d_evos(pretrained=False, **kwargs): +def resnetv2_50d_frn(pretrained=False, **kwargs): return _create_resnetv2( - 'resnetv2_50d_evos', pretrained=pretrained, - layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=EvoNormSample2d, + 'resnetv2_50d_frn', pretrained=pretrained, + layers=[3, 4, 6, 3], conv_layer=create_conv2d, norm_layer=FilterResponseNormTlu2d, stem_type='deep', avg_down=True, **kwargs) diff --git a/timm/models/vovnet.py b/timm/models/vovnet.py index ec5b3e81..608cd45b 100644 --- a/timm/models/vovnet.py +++ b/timm/models/vovnet.py @@ -395,7 +395,7 @@ def eca_vovnet39b(pretrained=False, **kwargs): @register_model def ese_vovnet39b_evos(pretrained=False, **kwargs): def norm_act_fn(num_features, **nkwargs): - return create_norm_act('EvoNormSample', num_features, jit=False, **nkwargs) + return create_norm_act('evonorms0', num_features, jit=False, **nkwargs) return _create_vovnet('ese_vovnet39b_evos', pretrained=pretrained, norm_layer=norm_act_fn, **kwargs)